mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 18:26:41 +00:00
[management] Add API of new network concept (#3012)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,9 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// AccountRequest holds the result channel to return the requested account.
|
||||
@@ -17,19 +20,19 @@ type AccountRequest struct {
|
||||
|
||||
// AccountResult holds the account data or an error.
|
||||
type AccountResult struct {
|
||||
Account *Account
|
||||
Account *types.Account
|
||||
Err error
|
||||
}
|
||||
|
||||
type AccountRequestBuffer struct {
|
||||
store Store
|
||||
store store.Store
|
||||
getAccountRequests map[string][]*AccountRequest
|
||||
mu sync.Mutex
|
||||
getAccountRequestCh chan *AccountRequest
|
||||
bufferInterval time.Duration
|
||||
}
|
||||
|
||||
func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer {
|
||||
func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
|
||||
bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
|
||||
bufferInterval, err := time.ParseDuration(bufferIntervalStr)
|
||||
if err != nil {
|
||||
@@ -52,7 +55,7 @@ func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBu
|
||||
|
||||
return &ac
|
||||
}
|
||||
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
|
||||
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
req := &AccountRequest{
|
||||
AccountID: accountID,
|
||||
ResultChan: make(chan *AccountResult, 1),
|
||||
|
||||
@@ -16,7 +16,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -29,7 +33,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -74,7 +80,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
|
||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
||||
}
|
||||
|
||||
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
|
||||
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
|
||||
t.Helper()
|
||||
peer := &nbpeer.Peer{
|
||||
Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=",
|
||||
@@ -102,7 +108,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
|
||||
}
|
||||
}
|
||||
|
||||
func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) {
|
||||
func verifyNewAccountHasDefaultFields(t *testing.T, account *types.Account, createdBy string, domain string, expectedUsers []string) {
|
||||
t.Helper()
|
||||
if len(account.Peers) != 0 {
|
||||
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
|
||||
@@ -157,7 +163,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
// peerID3 := "peer-3"
|
||||
tt := []struct {
|
||||
name string
|
||||
accountSettings Settings
|
||||
accountSettings types.Settings
|
||||
peerID string
|
||||
expectedPeers []string
|
||||
expectedOfflinePeers []string
|
||||
@@ -165,7 +171,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Should return ALL peers when global peer login expiration disabled",
|
||||
accountSettings: Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour},
|
||||
accountSettings: types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour},
|
||||
peerID: peerID1,
|
||||
expectedPeers: []string{peerID2},
|
||||
expectedOfflinePeers: []string{},
|
||||
@@ -203,7 +209,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Should return no peers when global peer login expiration enabled and peers expired",
|
||||
accountSettings: Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour},
|
||||
accountSettings: types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour},
|
||||
peerID: peerID1,
|
||||
expectedPeers: []string{},
|
||||
expectedOfflinePeers: []string{peerID2},
|
||||
@@ -397,12 +403,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
|
||||
netIP := net.IP{100, 64, 0, 0}
|
||||
netMask := net.IPMask{255, 255, 0, 0}
|
||||
network := &Network{
|
||||
network := &types.Network{
|
||||
Identifier: "network",
|
||||
Net: net.IPNet{IP: netIP, Mask: netMask},
|
||||
Dns: "netbird.selfhosted",
|
||||
Serial: 0,
|
||||
mu: sync.Mutex{},
|
||||
Mu: sync.Mutex{},
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
@@ -486,12 +492,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
}
|
||||
|
||||
initUnknown := defaultInitAccount
|
||||
initUnknown.DomainCategory = UnknownCategory
|
||||
initUnknown.DomainCategory = types.UnknownCategory
|
||||
initUnknown.Domain = unknownDomain
|
||||
|
||||
privateInitAccount := defaultInitAccount
|
||||
privateInitAccount.Domain = privateDomain
|
||||
privateInitAccount.DomainCategory = PrivateCategory
|
||||
privateInitAccount.DomainCategory = types.PrivateCategory
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -501,7 +507,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputUpdateClaimAccount bool
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedMSG string
|
||||
expectedUserRole UserRole
|
||||
expectedUserRole types.UserRole
|
||||
expectedDomainCategory string
|
||||
expectedDomain string
|
||||
expectedPrimaryDomainStatus bool
|
||||
@@ -513,12 +519,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
DomainCategory: types.PublicCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedUserRole: types.UserRoleOwner,
|
||||
expectedDomainCategory: "",
|
||||
expectedDomain: publicDomain,
|
||||
expectedPrimaryDomainStatus: false,
|
||||
@@ -530,12 +536,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: UnknownCategory,
|
||||
DomainCategory: types.UnknownCategory,
|
||||
},
|
||||
inputInitUserParams: initUnknown,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedUserRole: types.UserRoleOwner,
|
||||
expectedDomain: unknownDomain,
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
@@ -547,14 +553,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedUserRole: types.UserRoleOwner,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedDomainCategory: types.PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
@@ -564,15 +570,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
},
|
||||
inputUpdateAttrs: true,
|
||||
inputInitUserParams: privateInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleUser,
|
||||
expectedUserRole: types.UserRoleUser,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedDomainCategory: types.PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||
@@ -582,14 +588,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedUserRole: types.UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedDomainCategory: types.PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
@@ -599,15 +605,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
},
|
||||
inputUpdateClaimAccount: true,
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedUserRole: types.UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedDomainCategory: types.PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
@@ -617,12 +623,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedUserRole: types.UserRoleOwner,
|
||||
expectedDomain: "",
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
@@ -752,22 +758,26 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
account.Users["someUser"] = &types.User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -787,15 +797,20 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
account.Users["someUser"] = &types.User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
@@ -803,7 +818,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -905,7 +920,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||
|
||||
@@ -915,7 +930,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -991,13 +1006,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "example.com",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
}
|
||||
|
||||
publicClaims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "test.com",
|
||||
UserId: "public-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
DomainCategory: types.PublicCategory,
|
||||
}
|
||||
|
||||
am, err := createManager(b)
|
||||
@@ -1075,13 +1090,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
|
||||
}
|
||||
|
||||
func genUsers(p string, n int) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
func genUsers(p string, n int) map[string]*types.User {
|
||||
users := map[string]*types.User{}
|
||||
now := time.Now()
|
||||
for i := 0; i < n; i++ {
|
||||
users[fmt.Sprintf("%s-%d", p, i)] = &User{
|
||||
users[fmt.Sprintf("%s-%d", p, i)] = &types.User{
|
||||
Id: fmt.Sprintf("%s-%d", p, i),
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
LastLogin: now,
|
||||
CreatedAt: now,
|
||||
Issued: "api",
|
||||
@@ -1106,7 +1121,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -1243,15 +1258,15 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1335,15 +1350,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1368,15 +1383,15 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1427,15 +1442,15 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1483,7 +1498,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -1558,7 +1573,7 @@ func TestGetUsersFromAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users := map[string]*User{"1": {Id: "1", Role: UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}}
|
||||
users := map[string]*types.User{"1": {Id: "1", Role: types.UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}}
|
||||
accountId := "test_account_id"
|
||||
|
||||
account, err := createAccount(manager, accountId, users["1"].Id, "")
|
||||
@@ -1590,7 +1605,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"route-1": {
|
||||
ID: "route-1",
|
||||
@@ -1637,7 +1652,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||
},
|
||||
@@ -1682,7 +1697,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
@@ -1692,26 +1707,26 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
|
||||
func TestAccount_Copy(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Id: "account1",
|
||||
CreatedBy: "tester",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Domain: "test.com",
|
||||
DomainCategory: "public",
|
||||
IsDomainPrimaryAccount: true,
|
||||
SetupKeys: map[string]*SetupKey{
|
||||
SetupKeys: map[string]*types.SetupKey{
|
||||
"setup1": {
|
||||
Id: "setup1",
|
||||
AutoGroups: []string{"group1"},
|
||||
},
|
||||
},
|
||||
Network: &Network{
|
||||
Network: &types.Network{
|
||||
Identifier: "net1",
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
@@ -1724,12 +1739,12 @@ func TestAccount_Copy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Users: map[string]*User{
|
||||
Users: map[string]*types.User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
AutoGroups: []string{"group1"},
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
Name: "First PAT",
|
||||
@@ -1748,11 +1763,11 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "policy1",
|
||||
Enabled: true,
|
||||
Rules: make([]*PolicyRule, 0),
|
||||
Rules: make([]*types.PolicyRule, 0),
|
||||
SourcePostureChecks: make([]string, 0),
|
||||
},
|
||||
},
|
||||
@@ -1772,19 +1787,19 @@ func TestAccount_Copy(t *testing.T) {
|
||||
NameServers: []nbdns.NameServer{},
|
||||
},
|
||||
},
|
||||
DNSSettings: DNSSettings{DisabledManagementGroups: []string{}},
|
||||
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: "posture Checks1",
|
||||
},
|
||||
},
|
||||
Settings: &Settings{},
|
||||
Networks: []*networks.Network{
|
||||
Settings: &types.Settings{},
|
||||
Networks: []*networkTypes.Network{
|
||||
{
|
||||
ID: "network1",
|
||||
},
|
||||
},
|
||||
NetworkRouters: []*networks.NetworkRouter{
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{
|
||||
ID: "router1",
|
||||
NetworkID: "network1",
|
||||
@@ -1793,7 +1808,7 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Metric: 0,
|
||||
},
|
||||
},
|
||||
NetworkResources: []*networks.NetworkResource{
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: "resource1",
|
||||
NetworkID: "network1",
|
||||
@@ -1854,7 +1869,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.NotNil(t, settings)
|
||||
@@ -1887,7 +1902,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -1935,7 +1950,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -2004,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -2017,7 +2032,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
wg.Add(1)
|
||||
|
||||
// disabling PeerLoginExpirationEnabled should trigger cancel
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@@ -2035,7 +2050,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@@ -2043,19 +2058,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@@ -2128,9 +2143,9 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: testCase.peers,
|
||||
Settings: &Settings{
|
||||
Settings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
@@ -2212,9 +2227,9 @@ func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: testCase.peers,
|
||||
Settings: &Settings{
|
||||
Settings: &types.Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
@@ -2279,7 +2294,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: testCase.peers,
|
||||
}
|
||||
|
||||
@@ -2348,7 +2363,7 @@ func TestAccount_GetPeersWithInactivity(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: testCase.peers,
|
||||
}
|
||||
|
||||
@@ -2512,9 +2527,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: testCase.peers,
|
||||
Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled},
|
||||
Settings: &types.Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled},
|
||||
}
|
||||
|
||||
expiration, ok := account.GetNextPeerExpiration()
|
||||
@@ -2672,9 +2687,9 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: testCase.peers,
|
||||
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
|
||||
Settings: &types.Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
|
||||
}
|
||||
|
||||
expiration, ok := account.GetNextInactivePeerExpiration()
|
||||
@@ -2693,7 +2708,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
// create a new account
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Id: "accountID",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
||||
@@ -2705,8 +2720,8 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
Groups: map[string]*group.Group{
|
||||
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
|
||||
},
|
||||
Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
|
||||
Users: map[string]*User{
|
||||
Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "accountID"},
|
||||
"user2": {Id: "user2", AccountID: "accountID"},
|
||||
},
|
||||
@@ -2722,7 +2737,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
|
||||
})
|
||||
@@ -2735,18 +2750,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
|
||||
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
|
||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
@@ -2755,11 +2770,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2772,7 +2787,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2785,7 +2800,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2798,11 +2813,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
})
|
||||
@@ -2815,7 +2830,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
|
||||
@@ -2823,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
||||
@@ -2836,7 +2851,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) {
|
||||
"group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}},
|
||||
"group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}},
|
||||
},
|
||||
Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
|
||||
Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
|
||||
}
|
||||
|
||||
t.Run("add groups", func(t *testing.T) {
|
||||
@@ -2859,7 +2874,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
||||
@@ -2872,7 +2887,7 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
"group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}},
|
||||
"group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}},
|
||||
},
|
||||
Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
|
||||
Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
|
||||
}
|
||||
|
||||
t.Run("remove groups", func(t *testing.T) {
|
||||
@@ -2915,10 +2930,10 @@ func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t TB) (Store, error) {
|
||||
func createStore(t TB) (store.Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2941,7 +2956,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
t.Helper()
|
||||
|
||||
manager, err := createManager(t)
|
||||
@@ -2954,12 +2969,12 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
}
|
||||
|
||||
getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer {
|
||||
getPeer := func(manager *DefaultAccountManager, setupKey *types.SetupKey) *nbpeer.Peer {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -156,7 +157,7 @@ type ProviderConfig struct {
|
||||
|
||||
// StoreConfig contains Store configuration
|
||||
type StoreConfig struct {
|
||||
Engine StoreEngine
|
||||
Engine store.Engine
|
||||
}
|
||||
|
||||
// ReverseProxy contains reverse proxy configuration in front of management.
|
||||
|
||||
@@ -2,9 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -12,12 +10,12 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const defaultTTL = 300
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
@@ -62,26 +60,9 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
|
||||
type lookupMap map[string]struct{}
|
||||
|
||||
// DNSSettings defines dns settings at the account level
|
||||
type DNSSettings struct {
|
||||
// DisabledManagementGroups groups whose DNS management is disabled
|
||||
DisabledManagementGroups []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of the DNS settings
|
||||
func (d DNSSettings) Copy() DNSSettings {
|
||||
settings := DNSSettings{
|
||||
DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)),
|
||||
}
|
||||
copy(settings.DisabledManagementGroups, d.DisabledManagementGroups)
|
||||
return settings
|
||||
}
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,16 +75,16 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
|
||||
if dnsSettingsToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -119,18 +100,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||
oldSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
@@ -140,11 +121,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -162,11 +143,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
@@ -203,7 +184,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -217,12 +198,12 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, acc
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -298,81 +279,3 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.getPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
|
||||
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
peer.DNSLabel = label
|
||||
peerLabels[label] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) {
|
||||
label, err := nbdns.GetParsedDomainLabel(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
uniqueLabel := getUniqueHostLabel(label, peerLabels)
|
||||
if uniqueLabel == "" {
|
||||
return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label)
|
||||
}
|
||||
return uniqueLabel, nil
|
||||
}
|
||||
|
||||
// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999
|
||||
func getUniqueHostLabel(name string, peerLabels lookupMap) string {
|
||||
_, found := peerLabels[name]
|
||||
if !found {
|
||||
return name
|
||||
}
|
||||
for i := 1; i < 1000; i++ {
|
||||
nameWithSuffix := name + "-" + strconv.Itoa(i)
|
||||
_, found = peerLabels[nameWithSuffix]
|
||||
if !found {
|
||||
return nameWithSuffix
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -53,7 +55,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Fatal("DNS settings for new accounts shouldn't return nil")
|
||||
}
|
||||
|
||||
account.DNSSettings = DNSSettings{
|
||||
account.DNSSettings = types.DNSSettings{
|
||||
DisabledManagementGroups: []string{group1ID},
|
||||
}
|
||||
|
||||
@@ -86,20 +88,20 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
userID string
|
||||
inputSettings *DNSSettings
|
||||
inputSettings *types.DNSSettings
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
name: "Saving As Admin Should Be OK",
|
||||
userID: dnsAdminUserID,
|
||||
inputSettings: &DNSSettings{
|
||||
inputSettings: &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{dnsGroup1ID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should Not Update Settings As Regular User",
|
||||
userID: dnsRegularUserID,
|
||||
inputSettings: &DNSSettings{
|
||||
inputSettings: &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{dnsGroup1ID},
|
||||
},
|
||||
shouldFail: true,
|
||||
@@ -113,7 +115,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
{
|
||||
name: "Should Not Update Settings If Group Is Invalid",
|
||||
userID: dnsAdminUserID,
|
||||
inputSettings: &DNSSettings{
|
||||
inputSettings: &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"non-existing-group"},
|
||||
},
|
||||
shouldFail: true,
|
||||
@@ -210,10 +212,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -222,7 +224,7 @@ func createDNSStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) {
|
||||
t.Helper()
|
||||
peer1 := &nbpeer.Peer{
|
||||
Key: dnsPeer1Key,
|
||||
@@ -259,9 +261,9 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
|
||||
account.Users[dnsRegularUserID] = &User{
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
Role: UserRoleUser,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
@@ -510,7 +512,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@@ -589,7 +591,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@@ -609,7 +611,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@@ -629,7 +631,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,7 +23,7 @@ var (
|
||||
|
||||
type ephemeralPeer struct {
|
||||
id string
|
||||
account *Account
|
||||
account *types.Account
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
}
|
||||
@@ -32,7 +34,7 @@ type ephemeralPeer struct {
|
||||
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
|
||||
// automatically. Inactivity means the peer disconnected from the Management server.
|
||||
type EphemeralManager struct {
|
||||
store Store
|
||||
store store.Store
|
||||
accountManager AccountManager
|
||||
|
||||
headPeer *ephemeralPeer
|
||||
@@ -42,7 +44,7 @@ type EphemeralManager struct {
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager {
|
||||
func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager {
|
||||
return &EphemeralManager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
@@ -177,7 +179,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
|
||||
func (e *EphemeralManager) addPeer(id string, account *types.Account, deadline time.Time) {
|
||||
ep := &ephemeralPeer{
|
||||
id: id,
|
||||
account: account,
|
||||
|
||||
@@ -8,18 +8,20 @@ import (
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type MockStore struct {
|
||||
Store
|
||||
account *Account
|
||||
store.Store
|
||||
account *types.Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
|
||||
return []*Account{s.account}
|
||||
func (s *MockStore) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
return []*types.Account{s.account}
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
|
||||
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*types.Account, error) {
|
||||
_, ok := s.account.Peers[peerId]
|
||||
if ok {
|
||||
return s.account, nil
|
||||
|
||||
@@ -10,6 +10,9 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -28,7 +31,7 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -49,7 +52,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -57,12 +60,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@@ -76,7 +79,7 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -93,7 +96,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
var groupsToSave []*nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -113,11 +116,11 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -135,16 +138,16 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
@@ -153,7 +156,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
}
|
||||
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
@@ -194,21 +197,6 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// difference returns the elements in `a` that aren't in `b`.
|
||||
func difference(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
for _, x := range b {
|
||||
mb[x] = struct{}{}
|
||||
}
|
||||
var diff []string
|
||||
for _, x := range a {
|
||||
if _, found := mb[x]; !found {
|
||||
diff = append(diff, x)
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@@ -223,7 +211,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -240,9 +228,9 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*nbgroup.Group
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
@@ -257,11 +245,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -283,8 +271,8 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -298,11 +286,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -324,8 +312,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -339,11 +327,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -357,13 +345,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
// validateNewGroup validates the new group for existence and required fields.
|
||||
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||
func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *nbgroup.Group) error {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
return err
|
||||
@@ -380,7 +368,7 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string,
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
@@ -389,14 +377,14 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
@@ -429,8 +417,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.
|
||||
}
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *nbgroup.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -439,7 +427,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -452,8 +440,8 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -469,8 +457,8 @@ func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID stri
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -487,8 +475,8 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -506,8 +494,8 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -522,8 +510,8 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
@@ -538,12 +526,12 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -566,7 +554,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
@@ -576,8 +564,8 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||
func anyGroupHasPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -267,7 +268,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
@@ -342,9 +343,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
ID: "example policy",
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "example policy rule",
|
||||
Destinations: []string{groupForPolicies.ID},
|
||||
@@ -352,12 +353,12 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
},
|
||||
}
|
||||
|
||||
setupKey := &SetupKey{
|
||||
setupKey := &types.SetupKey{
|
||||
Id: "example setup key",
|
||||
AutoGroups: []string{groupForSetupKeys.ID},
|
||||
}
|
||||
|
||||
user := &User{
|
||||
user := &types.User{
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
@@ -500,15 +501,15 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -648,7 +649,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Saving a group linked to dns settings should update account peers and send peer update
|
||||
t.Run("saving group linked to dns settings", func(t *testing.T) {
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupD"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -599,7 +600,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Tok
|
||||
}
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig {
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
@@ -609,7 +610,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
|
||||
@@ -661,7 +662,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
var err error
|
||||
|
||||
var turnToken *Token
|
||||
|
||||
@@ -1176,6 +1176,105 @@ components:
|
||||
- id
|
||||
- network_type
|
||||
- $ref: '#/components/schemas/RouteRequest'
|
||||
NetworkRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Network name
|
||||
type: string
|
||||
example: Remote Network 1
|
||||
description:
|
||||
description: Network description
|
||||
type: string
|
||||
example: A remote network that needs to be accessed
|
||||
required:
|
||||
- name
|
||||
Network:
|
||||
allOf:
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Network ID
|
||||
type: string
|
||||
example: chacdk86lnnboviihd7g
|
||||
required:
|
||||
- id
|
||||
- $ref: '#/components/schemas/NetworkRequest'
|
||||
NetworkResourceRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Network resource name
|
||||
type: string
|
||||
example: Remote Resource 1
|
||||
description:
|
||||
description: Network resource description
|
||||
type: string
|
||||
example: A remote resource inside network 1
|
||||
address:
|
||||
description: Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com)
|
||||
type: string
|
||||
example: "1.1.1.1"
|
||||
required:
|
||||
- name
|
||||
- address
|
||||
NetworkResource:
|
||||
allOf:
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Network Resource ID
|
||||
type: string
|
||||
example: chacdk86lnnboviihd7g
|
||||
type:
|
||||
description: Network resource type based of the address
|
||||
type: string
|
||||
enum: [ "host", "subnet", "domain"]
|
||||
example: host
|
||||
required:
|
||||
- id
|
||||
- type
|
||||
- $ref: '#/components/schemas/NetworkResourceRequest'
|
||||
NetworkRouterRequest:
|
||||
type: object
|
||||
properties:
|
||||
peer:
|
||||
description: Peer Identifier associated with route. This property can not be set together with `peer_groups`
|
||||
type: string
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
peer_groups:
|
||||
description: Peers Group Identifier associated with route. This property can not be set together with `peer`
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: chacbco6lnnbn6cg5s91
|
||||
metric:
|
||||
description: Route metric number. Lowest number has higher priority
|
||||
type: integer
|
||||
maximum: 9999
|
||||
minimum: 1
|
||||
example: 9999
|
||||
masquerade:
|
||||
description: Indicate if peer should masquerade traffic to this route's prefix
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
# Only one property has to be set
|
||||
#- peer
|
||||
#- peer_groups
|
||||
- metric
|
||||
- masquerade
|
||||
NetworkRouter:
|
||||
allOf:
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Network Router Id
|
||||
type: string
|
||||
example: chacdk86lnnboviihd7g
|
||||
required:
|
||||
- id
|
||||
- $ref: '#/components/schemas/NetworkRouterRequest'
|
||||
Nameserver:
|
||||
type: object
|
||||
properties:
|
||||
@@ -2460,6 +2559,502 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/networks:
|
||||
get:
|
||||
summary: List all Networks
|
||||
description: Returns a list of all networks
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of Networks
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Network'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a Network
|
||||
description: Creates a Network
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
description: New Network request
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Network Object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Network'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/networks/{networkId}:
|
||||
get:
|
||||
summary: Retrieve a Network
|
||||
description: Get information about a Network
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
responses:
|
||||
'200':
|
||||
description: A Network object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Network'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
put:
|
||||
summary: Update a Network
|
||||
description: Update/Replace a Network
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
requestBody:
|
||||
description: Update Network request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Network object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Network'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a Network
|
||||
description: Delete a network
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/networks/{networkId}/resources:
|
||||
get:
|
||||
summary: List all Network Resources
|
||||
description: Returns a list of all resources in a network
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of Resources
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/NetworkResource'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a Network Resource
|
||||
description: Creates a Network Resource
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
requestBody:
|
||||
description: New Network Resource request
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkResourceRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Network Resource Object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkResource'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/networks/{networkId}/resources/{resourceId}:
|
||||
get:
|
||||
summary: Retrieve a Network Resource
|
||||
description: Get information about a Network Resource
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
- in: path
|
||||
name: resourceId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network resource
|
||||
responses:
|
||||
'200':
|
||||
description: A Network Resource object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkResource'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
put:
|
||||
summary: Update a Network Resource
|
||||
description: Update a Network Resource
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
- in: path
|
||||
name: resourceId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a resource
|
||||
requestBody:
|
||||
description: Update Network Resource request
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkResourceRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Network Resource object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkResource'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a Network Resource
|
||||
description: Delete a network resource
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
- in: path
|
||||
name: resourceId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network resource
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/networks/{networkId}/routers:
|
||||
get:
|
||||
summary: List all Network Routers
|
||||
description: Returns a list of all routers in a network
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of Routers
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/NetworkRouter'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a Network Router
|
||||
description: Creates a Network Router
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
requestBody:
|
||||
description: New Network Router request
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRouterRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Router Object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRouter'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/networks/{networkId}/routers/{routerId}:
|
||||
get:
|
||||
summary: Retrieve a Network Router
|
||||
description: Get information about a Network Router
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
- in: path
|
||||
name: routerId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a router
|
||||
responses:
|
||||
'200':
|
||||
description: A Router object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRouter'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
put:
|
||||
summary: Update a Network Router
|
||||
description: Update a Network Router
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
- in: path
|
||||
name: routerId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a router
|
||||
requestBody:
|
||||
description: Update Network Router request
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRouterRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Router object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/NetworkRouter'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a Network Router
|
||||
description: Delete a network router
|
||||
tags: [ Networks ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: networkId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a network
|
||||
- in: path
|
||||
name: routerId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a router
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/dns/nameservers:
|
||||
get:
|
||||
summary: List all Nameserver Groups
|
||||
|
||||
@@ -88,6 +88,13 @@ const (
|
||||
NameserverNsTypeUdp NameserverNsType = "udp"
|
||||
)
|
||||
|
||||
// Defines values for NetworkResourceType.
|
||||
const (
|
||||
NetworkResourceTypeDomain NetworkResourceType = "domain"
|
||||
NetworkResourceTypeHost NetworkResourceType = "host"
|
||||
NetworkResourceTypeSubnet NetworkResourceType = "subnet"
|
||||
)
|
||||
|
||||
// Defines values for PeerNetworkRangeCheckAction.
|
||||
const (
|
||||
PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow"
|
||||
@@ -494,6 +501,93 @@ type NameserverGroupRequest struct {
|
||||
SearchDomainsEnabled bool `json:"search_domains_enabled"`
|
||||
}
|
||||
|
||||
// Network defines model for Network.
|
||||
type Network struct {
|
||||
// Description Network description
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Id Network ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Name Network name
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// NetworkRequest defines model for NetworkRequest.
|
||||
type NetworkRequest struct {
|
||||
// Description Network description
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Name Network name
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// NetworkResource defines model for NetworkResource.
|
||||
type NetworkResource struct {
|
||||
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com)
|
||||
Address string `json:"address"`
|
||||
|
||||
// Description Network resource description
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Id Network Resource ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Name Network resource name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Type Network resource type based of the address
|
||||
Type NetworkResourceType `json:"type"`
|
||||
}
|
||||
|
||||
// NetworkResourceType Network resource type based of the address
|
||||
type NetworkResourceType string
|
||||
|
||||
// NetworkResourceRequest defines model for NetworkResourceRequest.
|
||||
type NetworkResourceRequest struct {
|
||||
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com)
|
||||
Address string `json:"address"`
|
||||
|
||||
// Description Network resource description
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Name Network resource name
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// NetworkRouter defines model for NetworkRouter.
|
||||
type NetworkRouter struct {
|
||||
// Id Network Router Id
|
||||
Id string `json:"id"`
|
||||
|
||||
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
|
||||
Masquerade bool `json:"masquerade"`
|
||||
|
||||
// Metric Route metric number. Lowest number has higher priority
|
||||
Metric int `json:"metric"`
|
||||
|
||||
// Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
|
||||
Peer *string `json:"peer,omitempty"`
|
||||
|
||||
// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
|
||||
PeerGroups *[]string `json:"peer_groups,omitempty"`
|
||||
}
|
||||
|
||||
// NetworkRouterRequest defines model for NetworkRouterRequest.
|
||||
type NetworkRouterRequest struct {
|
||||
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
|
||||
Masquerade bool `json:"masquerade"`
|
||||
|
||||
// Metric Route metric number. Lowest number has higher priority
|
||||
Metric int `json:"metric"`
|
||||
|
||||
// Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
|
||||
Peer *string `json:"peer,omitempty"`
|
||||
|
||||
// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
|
||||
PeerGroups *[]string `json:"peer_groups,omitempty"`
|
||||
}
|
||||
|
||||
// OSVersionCheck Posture check for the version of operating system
|
||||
type OSVersionCheck struct {
|
||||
// Android Posture check for the version of operating system
|
||||
@@ -1292,6 +1386,24 @@ type PostApiGroupsJSONRequestBody = GroupRequest
|
||||
// PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType.
|
||||
type PutApiGroupsGroupIdJSONRequestBody = GroupRequest
|
||||
|
||||
// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType.
|
||||
type PostApiNetworksJSONRequestBody = NetworkRequest
|
||||
|
||||
// PutApiNetworksNetworkIdJSONRequestBody defines body for PutApiNetworksNetworkId for application/json ContentType.
|
||||
type PutApiNetworksNetworkIdJSONRequestBody = NetworkRequest
|
||||
|
||||
// PostApiNetworksNetworkIdResourcesJSONRequestBody defines body for PostApiNetworksNetworkIdResources for application/json ContentType.
|
||||
type PostApiNetworksNetworkIdResourcesJSONRequestBody = NetworkResourceRequest
|
||||
|
||||
// PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody defines body for PutApiNetworksNetworkIdResourcesResourceId for application/json ContentType.
|
||||
type PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody = NetworkResourceRequest
|
||||
|
||||
// PostApiNetworksNetworkIdRoutersJSONRequestBody defines body for PostApiNetworksNetworkIdRouters for application/json ContentType.
|
||||
type PostApiNetworksNetworkIdRoutersJSONRequestBody = NetworkRouterRequest
|
||||
|
||||
// PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody defines body for PutApiNetworksNetworkIdRoutersRouterId for application/json ContentType.
|
||||
type PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody = NetworkRouterRequest
|
||||
|
||||
// PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType.
|
||||
type PutApiPeersPeerIdJSONRequestBody = PeerRequest
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
|
||||
@@ -93,6 +94,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
routes.AddEndpoints(api.AccountManager, authCfg, router)
|
||||
dns.AddEndpoints(api.AccountManager, authCfg, router)
|
||||
events.AddEndpoints(api.AccountManager, authCfg, router)
|
||||
networks.AddEndpoints(api.AccountManager.GetNetworksManager(), api.AccountManager.GetAccountIDFromToken, authCfg, router)
|
||||
|
||||
return rootRouter, nil
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
@@ -82,7 +83,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
settings := &server.Settings{
|
||||
settings := &types.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
@@ -138,7 +139,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
||||
func toAccountResponse(accountID string, settings *types.Settings) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
|
||||
@@ -13,23 +13,23 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func initAccountsTestData(account *server.Account, admin *server.User) *handler {
|
||||
func initAccountsTestData(account *types.Account, admin *types.User) *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return account.Id, admin.Id, nil
|
||||
},
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -58,19 +58,19 @@ func initAccountsTestData(account *server.Account, admin *server.User) *handler
|
||||
|
||||
func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := types.NewAdminUser("test_user")
|
||||
|
||||
sr := func(v string) *string { return &v }
|
||||
br := func(v bool) *bool { return &v }
|
||||
|
||||
handler := initAccountsTestData(&server.Account{
|
||||
handler := initAccountsTestData(&types.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Network: server.NewNetwork(),
|
||||
Users: map[string]*server.User{
|
||||
Network: types.NewNetwork(),
|
||||
Users: map[string]*types.User{
|
||||
adminUser.Id: adminUser,
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
Settings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
RegularUsersViewBlocked: true,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||
@@ -81,7 +82,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
updateDNSSettings := &server.DNSSettings{
|
||||
updateDNSSettings := &types.DNSSettings{
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -27,15 +27,15 @@ const (
|
||||
testDNSSettingsUserID = "test_user"
|
||||
)
|
||||
|
||||
var baseExistingDNSSettings = server.DNSSettings{
|
||||
var baseExistingDNSSettings = types.DNSSettings{
|
||||
DisabledManagementGroups: []string{testDNSSettingsExistingGroup},
|
||||
}
|
||||
|
||||
var testingDNSSettingsAccount = &server.Account{
|
||||
var testingDNSSettingsAccount = &types.Account{
|
||||
Id: testDNSSettingsAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
testDNSSettingsUserID: server.NewAdminUser("test_user"),
|
||||
Users: map[string]*types.User{
|
||||
testDNSSettingsUserID: types.NewAdminUser("test_user"),
|
||||
},
|
||||
DNSSettings: baseExistingDNSSettings,
|
||||
}
|
||||
@@ -43,10 +43,10 @@ var testingDNSSettingsAccount = &server.Account{
|
||||
func initDNSSettingsTestData() *dnsSettingsHandler {
|
||||
return &dnsSettingsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
||||
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
return &testingDNSSettingsAccount.DNSSettings, nil
|
||||
},
|
||||
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
|
||||
if dnsSettingsToSave != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
@@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
return make([]*server.UserInfo, 0), nil
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
return make([]*types.UserInfo, 0), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
@@ -191,7 +191,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
},
|
||||
}
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := types.NewAdminUser("test_user")
|
||||
events := generateEvents(accountID, adminUser.Id)
|
||||
handler := initEventsTestData(accountID, events...)
|
||||
|
||||
|
||||
180
management/server/http/handlers/networks/handler.go
Normal file
180
management/server/http/handlers/networks/handler.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package networks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// handler is a handler that returns networks of the account
|
||||
type handler struct {
|
||||
networksManager networks.Manager
|
||||
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func AddEndpoints(networksManager networks.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||||
networksHandler := newHandler(networksManager, extractFromToken, authCfg)
|
||||
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||||
addRouterEndpoints(networksManager.GetRouterManager(), extractFromToken, authCfg, router)
|
||||
addResourceEndpoints(networksManager.GetResourceManager(), extractFromToken, authCfg, router)
|
||||
}
|
||||
|
||||
func newHandler(networksManager networks.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
|
||||
return &handler{
|
||||
networksManager: networksManager,
|
||||
extractFromToken: extractFromToken,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var networkResponse []*api.Network
|
||||
for _, network := range networks {
|
||||
networkResponse = append(networkResponse, network.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, networkResponse)
|
||||
}
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
network := &types.Network{}
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
network := &types.Network{}
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.ID = networkID
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
162
management/server/http/handlers/networks/resources_handler.go
Normal file
162
management/server/http/handlers/networks/resources_handler.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package networks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
)
|
||||
|
||||
type resourceHandler struct {
|
||||
resourceManager resources.Manager
|
||||
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||||
resourceHandler := newResourceHandler(resourcesManager, extractFromToken, authCfg)
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResources).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newResourceHandler(resourceManager resources.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler {
|
||||
return &resourceHandler{
|
||||
resourceManager: resourceManager,
|
||||
extractFromToken: extractFromToken,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getAllResources(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resources, err := h.resourceManager.GetAllResources(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var resourcesResponse []*api.NetworkResource
|
||||
for _, resource := range resources {
|
||||
resourcesResponse = append(resourcesResponse, resource.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
resource := &types.NetworkResource{}
|
||||
resource.FromAPIRequest(&req)
|
||||
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
resource := &types.NetworkResource{}
|
||||
resource.FromAPIRequest(&req)
|
||||
|
||||
resource.ID = mux.Vars(r)["resourceId"]
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
165
management/server/http/handlers/networks/routers_handler.go
Normal file
165
management/server/http/handlers/networks/routers_handler.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package networks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
)
|
||||
|
||||
type routersHandler struct {
|
||||
routersManager routers.Manager
|
||||
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg)
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler {
|
||||
return &routersHandler{
|
||||
routersManager: routersManager,
|
||||
extractFromToken: extractFromToken,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
routers, err := h.routersManager.GetAllRouters(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var routersResponse []*api.NetworkRouter
|
||||
for _, router := range routers {
|
||||
routersResponse = append(routersResponse, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
router := &types.NetworkRouter{}
|
||||
router.FromAPIRequest(&req)
|
||||
|
||||
router.NetworkID = networkID
|
||||
router.AccountID = accountID
|
||||
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
router := &types.NetworkRouter{}
|
||||
router.FromAPIRequest(&req)
|
||||
|
||||
router.NetworkID = mux.Vars(r)["networkId"]
|
||||
router.ID = mux.Vars(r)["routerId"]
|
||||
router.AccountID = accountID
|
||||
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, struct{}{})
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// Handler is a handler that returns peers of the account
|
||||
@@ -57,7 +58,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
return peerToReturn, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
@@ -84,7 +85,7 @@ func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID,
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -295,7 +296,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
|
||||
func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer {
|
||||
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
|
||||
for _, p := range netMap.Peers {
|
||||
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@@ -73,18 +73,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
policy := &types.Policy{
|
||||
ID: "policy",
|
||||
AccountID: accountID,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
@@ -99,16 +99,16 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser := types.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
Users: map[string]*types.User{
|
||||
adminUser: types.NewAdminUser(adminUser),
|
||||
regularUser: types.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
@@ -120,12 +120,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
Settings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Policies: []*types.Policy{policy},
|
||||
Network: &types.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
IP: net.ParseIP("100.67.0.0"),
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -46,8 +46,8 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
return server.NewAdminUser(id), nil
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return types.NewAdminUser(id), nil
|
||||
},
|
||||
},
|
||||
geolocationManager: geo,
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns policy of the account
|
||||
@@ -133,7 +134,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
return
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
policy := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
@@ -146,7 +147,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
ruleID = *rule.Id
|
||||
}
|
||||
|
||||
pr := server.PolicyRule{
|
||||
pr := types.PolicyRule{
|
||||
ID: ruleID,
|
||||
PolicyID: policyID,
|
||||
Name: rule.Name,
|
||||
@@ -162,9 +163,9 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
|
||||
switch rule.Action {
|
||||
case api.PolicyRuleUpdateActionAccept:
|
||||
pr.Action = server.PolicyTrafficActionAccept
|
||||
pr.Action = types.PolicyTrafficActionAccept
|
||||
case api.PolicyRuleUpdateActionDrop:
|
||||
pr.Action = server.PolicyTrafficActionDrop
|
||||
pr.Action = types.PolicyTrafficActionDrop
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
||||
return
|
||||
@@ -172,13 +173,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
|
||||
switch rule.Protocol {
|
||||
case api.PolicyRuleUpdateProtocolAll:
|
||||
pr.Protocol = server.PolicyRuleProtocolALL
|
||||
pr.Protocol = types.PolicyRuleProtocolALL
|
||||
case api.PolicyRuleUpdateProtocolTcp:
|
||||
pr.Protocol = server.PolicyRuleProtocolTCP
|
||||
pr.Protocol = types.PolicyRuleProtocolTCP
|
||||
case api.PolicyRuleUpdateProtocolUdp:
|
||||
pr.Protocol = server.PolicyRuleProtocolUDP
|
||||
pr.Protocol = types.PolicyRuleProtocolUDP
|
||||
case api.PolicyRuleUpdateProtocolIcmp:
|
||||
pr.Protocol = server.PolicyRuleProtocolICMP
|
||||
pr.Protocol = types.PolicyRuleProtocolICMP
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||
return
|
||||
@@ -205,7 +206,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
|
||||
return
|
||||
}
|
||||
pr.PortRanges = append(pr.PortRanges, server.RulePortRange{
|
||||
pr.PortRanges = append(pr.PortRanges, types.RulePortRange{
|
||||
Start: uint16(portRange.Start),
|
||||
End: uint16(portRange.End),
|
||||
})
|
||||
@@ -214,7 +215,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
|
||||
// validate policy object
|
||||
switch pr.Protocol {
|
||||
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
|
||||
case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP:
|
||||
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||
return
|
||||
@@ -223,7 +224,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
|
||||
case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP:
|
||||
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
@@ -319,7 +320,7 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
|
||||
func toPolicyResponse(groups []*nbgroup.Group, policy *types.Policy) *api.Policy {
|
||||
groupsMap := make(map[string]*nbgroup.Group)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
@@ -20,25 +21,24 @@ import (
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
func initPoliciesTestData(policies ...*server.Policy) *handler {
|
||||
testPolicies := make(map[string]*server.Policy, len(policies))
|
||||
func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||
testPolicies := make(map[string]*types.Policy, len(policies))
|
||||
for _, policy := range policies {
|
||||
testPolicies[policy.ID] = policy
|
||||
}
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
|
||||
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*types.Policy, error) {
|
||||
policy, ok := testPolicies[policyID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "policy not found")
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
@@ -51,19 +51,19 @@ func initPoliciesTestData(policies ...*server.Policy) *handler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
user := server.NewAdminUser(userID)
|
||||
return &server.Account{
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
user := types.NewAdminUser(userID)
|
||||
return &types.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Policies: []*server.Policy{
|
||||
Policies: []*types.Policy{
|
||||
{ID: "id-existed"},
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"F": {ID: "F"},
|
||||
"G": {ID: "G"},
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
Users: map[string]*types.User{
|
||||
"test_user": user,
|
||||
},
|
||||
}, nil
|
||||
@@ -105,10 +105,10 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
policy := &types.Policy{
|
||||
ID: "idofthepolicy",
|
||||
Name: "Rule",
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{ID: "idoftherule", Name: "Rule"},
|
||||
},
|
||||
}
|
||||
@@ -251,10 +251,10 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initPoliciesTestData(&server.Policy{
|
||||
p := initPoliciesTestData(&types.Policy{
|
||||
ID: "id-existed",
|
||||
Name: "Default POSTed Rule",
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "id-existed",
|
||||
Name: "Default POSTed Rule",
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/magiconair/properties/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -61,7 +61,7 @@ var baseExistingRoute = &route.Route{
|
||||
Groups: []string{existingGroupID},
|
||||
}
|
||||
|
||||
var testingAccount = &server.Account{
|
||||
var testingAccount = &types.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
@@ -82,8 +82,8 @@ var testingAccount = &server.Account{
|
||||
},
|
||||
},
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
Users: map[string]*types.User{
|
||||
"test_user": types.NewAdminUser("test_user"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns a list of setup keys of the account
|
||||
@@ -63,8 +64,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
|
||||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
|
||||
if !(types.SetupKeyType(req.Type) == types.SetupKeyReusable ||
|
||||
types.SetupKeyType(req.Type) == types.SetupKeyOneOff) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
|
||||
return
|
||||
}
|
||||
@@ -85,7 +86,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
ephemeral = *req.Ephemeral
|
||||
}
|
||||
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -152,7 +153,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
newKey := &server.SetupKey{}
|
||||
newKey := &types.SetupKey{}
|
||||
newKey.AutoGroups = req.AutoGroups
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Id = keyID
|
||||
@@ -212,7 +213,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
||||
@@ -222,7 +223,7 @@ func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupK
|
||||
}
|
||||
}
|
||||
|
||||
func toResponseBody(key *server.SetupKey) *api.SetupKey {
|
||||
func toResponseBody(key *types.SetupKey) *api.SetupKey {
|
||||
var state string
|
||||
switch {
|
||||
case key.IsExpired():
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,17 +29,17 @@ const (
|
||||
testAccountID = "test_id"
|
||||
)
|
||||
|
||||
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
|
||||
user *server.User,
|
||||
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
|
||||
user *types.User,
|
||||
) *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string, ephemeral bool,
|
||||
) (*server.SetupKey, error) {
|
||||
) (*types.SetupKey, error) {
|
||||
if keyName == newKey.Name || typ != newKey.Type {
|
||||
nk := newKey.Copy()
|
||||
nk.Ephemeral = ephemeral
|
||||
@@ -47,7 +47,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
return nil, fmt.Errorf("failed creating setup key")
|
||||
},
|
||||
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
|
||||
switch keyID {
|
||||
case defaultKey.Id:
|
||||
return defaultKey, nil
|
||||
@@ -58,15 +58,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
},
|
||||
|
||||
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
|
||||
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *types.SetupKey, _ string) (*types.SetupKey, error) {
|
||||
if key.Id == updatedSetupKey.Id {
|
||||
return updatedSetupKey, nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
|
||||
},
|
||||
|
||||
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||
return []*server.SetupKey{defaultKey}, nil
|
||||
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*types.SetupKey, error) {
|
||||
return []*types.SetupKey{defaultKey}, nil
|
||||
},
|
||||
|
||||
DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error {
|
||||
@@ -89,13 +89,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
|
||||
func TestSetupKeysHandlers(t *testing.T) {
|
||||
defaultSetupKey, _ := server.GenerateDefaultSetupKey()
|
||||
defaultSetupKey, _ := types.GenerateDefaultSetupKey()
|
||||
defaultSetupKey.Id = existingSetupKeyID
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := types.NewAdminUser("test_user")
|
||||
|
||||
newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||
server.SetupKeyUnlimitedUsage, true)
|
||||
newSetupKey, plainKey := types.GenerateSetupKey(newSetupKeyName, types.SetupKeyReusable, 0, []string{"group-1"},
|
||||
types.SetupKeyUnlimitedUsage, true)
|
||||
newSetupKey.Key = plainKey
|
||||
updatedDefaultSetupKey := defaultSetupKey.Copy()
|
||||
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// patHandler is the nameserver group handler of the account
|
||||
@@ -164,7 +165,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
var lastUsed *time.Time
|
||||
if !pat.LastUsed.IsZero() {
|
||||
lastUsed = &pat.LastUsed
|
||||
@@ -179,7 +180,7 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
}
|
||||
}
|
||||
|
||||
func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
|
||||
func toPATGeneratedResponse(pat *types.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
|
||||
return &api.PersonalAccessTokenGenerated{
|
||||
PlainToken: pat.PlainToken,
|
||||
PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken),
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,13 +31,13 @@ const (
|
||||
testDomain = "hotmail.com"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
var testAccount = &types.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
Users: map[string]*types.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
existingTokenID: {
|
||||
ID: existingTokenID,
|
||||
Name: "My first token",
|
||||
@@ -64,16 +64,16 @@ var testAccount = &server.Account{
|
||||
func initPATTestData() *patHandler {
|
||||
return &patHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return &server.PersonalAccessTokenGenerated{
|
||||
return &types.PersonalAccessTokenGenerated{
|
||||
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
|
||||
PersonalAccessToken: server.PersonalAccessToken{},
|
||||
PersonalAccessToken: types.PersonalAccessToken{},
|
||||
}, nil
|
||||
},
|
||||
|
||||
@@ -92,7 +92,7 @@ func initPATTestData() *patHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@@ -104,14 +104,14 @@ func initPATTestData() *patHandler {
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
@@ -217,7 +217,7 @@ func TestTokenHandlers(t *testing.T) {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.NotEmpty(t, got.PlainToken)
|
||||
assert.Equal(t, server.PATLength, len(got.PlainToken))
|
||||
assert.Equal(t, types.PATLength, len(got.PlainToken))
|
||||
case "Get All Tokens":
|
||||
expectedTokens := []api.PersonalAccessToken{
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
|
||||
@@ -243,7 +243,7 @@ func TestTokenHandlers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
|
||||
func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken {
|
||||
return api.PersonalAccessToken{
|
||||
Id: serverToken.ID,
|
||||
Name: serverToken.Name,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -83,13 +84,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
userRole := types.StrRoleToUserRole(req.Role)
|
||||
if userRole == types.UserRoleUnknown {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@@ -156,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||
if types.StrRoleToUserRole(req.Role) == types.UserRoleUnknown {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
return
|
||||
}
|
||||
@@ -171,13 +172,13 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
IsServiceUser: req.IsServiceUser,
|
||||
Issued: server.UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -264,7 +265,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
autoGroups := user.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,37 +26,37 @@ const (
|
||||
regularUserID = "regularUserID"
|
||||
)
|
||||
|
||||
var usersTestAccount = &server.Account{
|
||||
var usersTestAccount = &types.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
Users: map[string]*types.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
Issued: server.UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
Issued: server.UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group_1"},
|
||||
Issued: server.UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
nonDeletableServiceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: true,
|
||||
NonDeletable: true,
|
||||
Issued: server.UserIssuedIntegration,
|
||||
Issued: types.UserIssuedIntegration,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -67,13 +67,13 @@ func initUsersTestData() *handler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return usersTestAccount.Id, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
users := make([]*types.UserInfo, 0)
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &server.UserInfo{
|
||||
users = append(users, &types.UserInfo{
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
@@ -85,7 +85,7 @@ func initUsersTestData() *handler {
|
||||
}
|
||||
return users, nil
|
||||
},
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func initUsersTestData() *handler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
SaveUserFunc: func(_ context.Context, accountID, userID string, update *types.User) (*types.UserInfo, error) {
|
||||
if update.Id == notFoundUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||
}
|
||||
@@ -109,7 +109,7 @@ func initUsersTestData() *handler {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
|
||||
info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false})
|
||||
info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func TestGetUsers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
respBody := []*server.UserInfo{}
|
||||
respBody := []*types.UserInfo{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
@@ -342,7 +342,7 @@ func TestCreateUser(t *testing.T) {
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
expectedResult []*server.User
|
||||
expectedResult []*types.User
|
||||
}{
|
||||
{name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
|
||||
// right now creation is blocked in AC middleware, will be refactored in the future
|
||||
|
||||
@@ -7,16 +7,16 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
|
||||
@@ -11,16 +11,16 @@ import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,13 +28,13 @@ const (
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
var testAccount = &types.Account{
|
||||
Id: accountID,
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
Users: map[string]*types.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
Name: "My first token",
|
||||
@@ -49,7 +49,7 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
|
||||
@@ -57,9 +59,9 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -73,6 +75,6 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
|
||||
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,9 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -413,7 +415,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -618,7 +620,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -705,7 +707,7 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
|
||||
setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), types.SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
|
||||
if err != nil {
|
||||
t.Logf("error creating setup key: %v", err)
|
||||
return
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -532,7 +533,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s := grpc.NewServer()
|
||||
|
||||
store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
|
||||
store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,8 @@ import (
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -47,8 +48,8 @@ type properties map[string]interface{}
|
||||
|
||||
// DataSource metric data source
|
||||
type DataSource interface {
|
||||
GetAllAccounts(ctx context.Context) []*server.Account
|
||||
GetStoreEngine() server.StoreEngine
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetStoreEngine() store.Engine
|
||||
}
|
||||
|
||||
// ConnManager peer connection manager that holds state for current active connections
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -22,12 +23,12 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
|
||||
}
|
||||
|
||||
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
return []*server.Account{
|
||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
return []*types.Account{
|
||||
{
|
||||
Id: "1",
|
||||
Settings: &server.Settings{PeerLoginExpirationEnabled: true},
|
||||
SetupKeys: map[string]*server.SetupKey{
|
||||
Settings: &types.Settings{PeerLoginExpirationEnabled: true},
|
||||
SetupKeys: map[string]*types.SetupKey{
|
||||
"1": {
|
||||
Id: "1",
|
||||
Ephemeral: true,
|
||||
@@ -49,20 +50,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
|
||||
},
|
||||
},
|
||||
Policies: []*server.Policy{
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Bidirectional: true,
|
||||
Protocol: server.PolicyRuleProtocolTCP,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Bidirectional: false,
|
||||
Protocol: server.PolicyRuleProtocolTCP,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"1"},
|
||||
@@ -94,16 +95,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
},
|
||||
},
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
Users: map[string]*types.User{
|
||||
"1": {
|
||||
IsServiceUser: true,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
"2": {
|
||||
IsServiceUser: false,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
@@ -111,8 +112,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
},
|
||||
{
|
||||
Id: "2",
|
||||
Settings: &server.Settings{PeerLoginExpirationEnabled: true},
|
||||
SetupKeys: map[string]*server.SetupKey{
|
||||
Settings: &types.Settings{PeerLoginExpirationEnabled: true},
|
||||
SetupKeys: map[string]*types.SetupKey{
|
||||
"1": {
|
||||
Id: "1",
|
||||
Ephemeral: true,
|
||||
@@ -134,20 +135,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
|
||||
},
|
||||
},
|
||||
Policies: []*server.Policy{
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Bidirectional: true,
|
||||
Protocol: server.PolicyRuleProtocolTCP,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Rules: []*server.PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Bidirectional: false,
|
||||
Protocol: server.PolicyRuleProtocolTCP,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -158,16 +159,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
PeerGroups: make([]string, 1),
|
||||
},
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
Users: map[string]*types.User{
|
||||
"1": {
|
||||
IsServiceUser: true,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
"2": {
|
||||
IsServiceUser: false,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
@@ -177,8 +178,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
}
|
||||
|
||||
// GetStoreEngine returns FileStoreEngine
|
||||
func (mockDatasource) GetStoreEngine() server.StoreEngine {
|
||||
return server.FileStoreEngine
|
||||
func (mockDatasource) GetStoreEngine() store.Engine {
|
||||
return store.FileStoreEngine
|
||||
}
|
||||
|
||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||
@@ -267,7 +268,7 @@ func TestGenerateProperties(t *testing.T) {
|
||||
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
|
||||
}
|
||||
|
||||
if properties["store_engine"] != server.FileStoreEngine {
|
||||
if properties["store_engine"] != store.FileStoreEngine {
|
||||
t.Errorf("expected JsonFile, got %s", properties["store_engine"])
|
||||
}
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -31,64 +31,64 @@ func setupDatabase(t *testing.T) *gorm.DB {
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
err := migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||
}
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &route.Route{})
|
||||
err := db.AutoMigrate(&types.Account{}, &route.Route{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||
require.NoError(t, err, "Failed to parse CIDR")
|
||||
|
||||
type network struct {
|
||||
server.Network
|
||||
types.Network
|
||||
Net net.IPNet `gorm:"serializer:gob"`
|
||||
}
|
||||
|
||||
type account struct {
|
||||
server.Account
|
||||
types.Account
|
||||
Network *network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
|
||||
err = db.Save(&account{Account: types.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
|
||||
require.NoError(t, err, "Failed to insert Gob data")
|
||||
|
||||
var gobStr string
|
||||
err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error
|
||||
err = db.Model(&types.Account{}).Select("network_net").First(&gobStr).Error
|
||||
assert.NoError(t, err, "Failed to fetch Gob data")
|
||||
|
||||
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
|
||||
require.NoError(t, err, "Failed to decode Gob data")
|
||||
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail with Gob data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
||||
db.Model(&types.Account{}).Select("network_net").First(&jsonStr)
|
||||
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated")
|
||||
}
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &route.Route{})
|
||||
err := db.AutoMigrate(&types.Account{}, &route.Route{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||
require.NoError(t, err, "Failed to parse CIDR")
|
||||
|
||||
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
|
||||
err = db.Save(&types.Account{Network: &types.Network{Net: *ipnet}}).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail with JSON data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
||||
db.Model(&types.Account{}).Select("network_net").First(&jsonStr)
|
||||
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
|
||||
}
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
type location struct {
|
||||
@@ -115,12 +115,12 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
}
|
||||
|
||||
type account struct {
|
||||
server.Account
|
||||
types.Account
|
||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||
}
|
||||
|
||||
err = db.Save(&account{
|
||||
Account: server.Account{Id: "123"},
|
||||
Account: types.Account{Id: "123"},
|
||||
Peers: []peer{
|
||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
@@ -142,10 +142,10 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.Account{
|
||||
err = db.Save(&types.Account{
|
||||
Id: "1234",
|
||||
PeersG: []nbpeer.Peer{
|
||||
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
@@ -164,20 +164,20 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.SetupKey{})
|
||||
err := db.AutoMigrate(&types.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.SetupKey{
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db)
|
||||
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||
|
||||
var key server.SetupKey
|
||||
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||
var key types.SetupKey
|
||||
err = db.Model(&types.SetupKey{}).First(&key).Error
|
||||
assert.NoError(t, err, "Failed to fetch setup key")
|
||||
|
||||
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
|
||||
@@ -187,21 +187,21 @@ func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.SetupKey{})
|
||||
err := db.AutoMigrate(&types.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.SetupKey{
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
KeySecret: "EEFDA****",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db)
|
||||
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||
|
||||
var key server.SetupKey
|
||||
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||
var key types.SetupKey
|
||||
err = db.Model(&types.SetupKey{}).First(&key).Error
|
||||
assert.NoError(t, err, "Failed to fetch setup key")
|
||||
|
||||
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
|
||||
@@ -211,20 +211,20 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
|
||||
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.SetupKey{})
|
||||
err := db.AutoMigrate(&types.SetupKey{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.SetupKey{
|
||||
err = db.Save(&types.SetupKey{
|
||||
Id: "1",
|
||||
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||
}).Error
|
||||
require.NoError(t, err, "Failed to insert setup key")
|
||||
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||
err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db)
|
||||
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||
|
||||
var key server.SetupKey
|
||||
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||
var key types.SetupKey
|
||||
err = db.Model(&types.SetupKey{}).First(&key).Error
|
||||
assert.NoError(t, err, "Failed to fetch setup key")
|
||||
|
||||
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||
|
||||
@@ -16,28 +16,30 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
|
||||
@@ -48,12 +50,12 @@ type MockAccountManager struct {
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -62,35 +64,35 @@ type MockAccountManager struct {
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
||||
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
@@ -105,12 +107,17 @@ type MockAccountManager struct {
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetNetworksManager() networks.Manager {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
if am.DeleteSetupKeyFunc != nil {
|
||||
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
|
||||
@@ -118,7 +125,7 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
}
|
||||
@@ -130,7 +137,7 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
approvedPeers[id] = struct{}{}
|
||||
@@ -155,7 +162,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
|
||||
if am.GetUsersFromAccountFunc != nil {
|
||||
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
|
||||
}
|
||||
@@ -173,7 +180,7 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
|
||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
ctx context.Context, userId, domain string,
|
||||
) (*server.Account, error) {
|
||||
) (*types.Account, error) {
|
||||
if am.GetOrCreateAccountByUserFunc != nil {
|
||||
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
|
||||
}
|
||||
@@ -188,13 +195,13 @@ func (am *MockAccountManager) CreateSetupKey(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
keyName string,
|
||||
keyType server.SetupKeyType,
|
||||
keyType types.SetupKeyType,
|
||||
expiresIn time.Duration,
|
||||
autoGroups []string,
|
||||
usageLimit int,
|
||||
userID string,
|
||||
ephemeral bool,
|
||||
) (*server.SetupKey, error) {
|
||||
) (*types.SetupKey, error) {
|
||||
if am.CreateSetupKeyFunc != nil {
|
||||
return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral)
|
||||
}
|
||||
@@ -221,7 +228,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
|
||||
}
|
||||
@@ -229,7 +236,7 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
}
|
||||
@@ -253,7 +260,7 @@ func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error
|
||||
}
|
||||
|
||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
if am.CreatePATFunc != nil {
|
||||
return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||
}
|
||||
@@ -269,7 +276,7 @@ func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, i
|
||||
}
|
||||
|
||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
|
||||
if am.GetPATFunc != nil {
|
||||
return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
@@ -277,7 +284,7 @@ func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, init
|
||||
}
|
||||
|
||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
|
||||
if am.GetAllPATsFunc != nil {
|
||||
return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
@@ -285,7 +292,7 @@ func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string,
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) {
|
||||
func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*types.NetworkMap, error) {
|
||||
if am.GetNetworkMapFunc != nil {
|
||||
return am.GetNetworkMapFunc(ctx, peerKey)
|
||||
}
|
||||
@@ -293,7 +300,7 @@ func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string)
|
||||
}
|
||||
|
||||
// GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) {
|
||||
func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*types.Network, error) {
|
||||
if am.GetPeerNetworkFunc != nil {
|
||||
return am.GetPeerNetworkFunc(ctx, peerKey)
|
||||
}
|
||||
@@ -306,7 +313,7 @@ func (am *MockAccountManager) AddPeer(
|
||||
setupKey string,
|
||||
userId string,
|
||||
peer *nbpeer.Peer,
|
||||
) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if am.AddPeerFunc != nil {
|
||||
return am.AddPeerFunc(ctx, setupKey, userId, peer)
|
||||
}
|
||||
@@ -378,7 +385,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID,
|
||||
}
|
||||
|
||||
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) {
|
||||
func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
|
||||
if am.GetPolicyFunc != nil {
|
||||
return am.GetPolicyFunc(ctx, accountID, policyID, userID)
|
||||
}
|
||||
@@ -386,7 +393,7 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
||||
}
|
||||
|
||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
if am.SavePolicyFunc != nil {
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||
}
|
||||
@@ -402,7 +409,7 @@ func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, polic
|
||||
}
|
||||
|
||||
// ListPolicies mock implementation of ListPolicies from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) {
|
||||
func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
||||
if am.ListPoliciesFunc != nil {
|
||||
return am.ListPoliciesFunc(ctx, accountID, userID)
|
||||
}
|
||||
@@ -418,14 +425,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
|
||||
}
|
||||
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
|
||||
if am.GetUserFunc != nil {
|
||||
return am.GetUserFunc(ctx, claims)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) {
|
||||
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||
if am.ListUsersFunc != nil {
|
||||
return am.ListUsersFunc(ctx, accountID)
|
||||
}
|
||||
@@ -481,7 +488,7 @@ func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID
|
||||
}
|
||||
|
||||
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) {
|
||||
func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) {
|
||||
if am.SaveSetupKeyFunc != nil {
|
||||
return am.SaveSetupKeyFunc(ctx, accountID, key, userID)
|
||||
}
|
||||
@@ -490,7 +497,7 @@ func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string
|
||||
}
|
||||
|
||||
// GetSetupKey mocks GetSetupKey of the AccountManager interface
|
||||
func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
|
||||
if am.GetSetupKeyFunc != nil {
|
||||
return am.GetSetupKeyFunc(ctx, accountID, userID, keyID)
|
||||
}
|
||||
@@ -499,7 +506,7 @@ func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID
|
||||
}
|
||||
|
||||
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
|
||||
func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||
func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) {
|
||||
if am.ListSetupKeysFunc != nil {
|
||||
return am.ListSetupKeysFunc(ctx, accountID, userID)
|
||||
}
|
||||
@@ -508,7 +515,7 @@ func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
// SaveUser mocks SaveUser of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) {
|
||||
if am.SaveUserFunc != nil {
|
||||
return am.SaveUserFunc(ctx, accountID, userID, user)
|
||||
}
|
||||
@@ -516,7 +523,7 @@ func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID st
|
||||
}
|
||||
|
||||
// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) {
|
||||
if am.SaveOrAddUserFunc != nil {
|
||||
return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists)
|
||||
}
|
||||
@@ -524,7 +531,7 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) {
|
||||
if am.SaveOrAddUsersFunc != nil {
|
||||
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
|
||||
}
|
||||
@@ -595,7 +602,7 @@ func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// CreateUser mocks CreateUser of the AccountManager interface
|
||||
func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) {
|
||||
if am.CreateUserFunc != nil {
|
||||
return am.CreateUserFunc(ctx, accountID, userID, invite)
|
||||
}
|
||||
@@ -642,7 +649,7 @@ func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID s
|
||||
}
|
||||
|
||||
// GetDNSSettings mocks GetDNSSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
||||
func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
if am.GetDNSSettingsFunc != nil {
|
||||
return am.GetDNSSettingsFunc(ctx, accountID, userID)
|
||||
}
|
||||
@@ -650,7 +657,7 @@ func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID stri
|
||||
}
|
||||
|
||||
// SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
|
||||
if am.SaveDNSSettingsFunc != nil {
|
||||
return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave)
|
||||
}
|
||||
@@ -666,7 +673,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
|
||||
}
|
||||
@@ -674,7 +681,7 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// LoginPeer mocks LoginPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if am.LoginPeerFunc != nil {
|
||||
return am.LoginPeerFunc(ctx, login)
|
||||
}
|
||||
@@ -682,7 +689,7 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(ctx, sync, account)
|
||||
}
|
||||
@@ -803,7 +810,7 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
|
||||
}
|
||||
|
||||
// GetAccountByID mocks GetAccountByID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
if am.GetAccountByIDFunc != nil {
|
||||
return am.GetAccountByIDFunc(ctx, accountID, userID)
|
||||
}
|
||||
@@ -811,21 +818,21 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri
|
||||
}
|
||||
|
||||
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
|
||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
if am.GetUserByIDFunc != nil {
|
||||
return am.GetUserByIDFunc(ctx, id)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
if am.GetAccountSettingsFunc != nil {
|
||||
return am.GetAccountSettingsFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
|
||||
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
if am.GetAccountFunc != nil {
|
||||
return am.GetAccountFunc(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -13,13 +13,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -32,7 +33,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
@@ -40,7 +41,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -64,7 +65,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -74,11 +75,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
|
||||
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -113,8 +114,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -129,11 +130,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
|
||||
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -153,7 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -165,8 +166,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -176,11 +177,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||
return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -197,7 +198,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,10 +211,10 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -224,7 +225,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s
|
||||
return err
|
||||
}
|
||||
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -234,7 +235,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -243,7 +244,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -772,10 +774,10 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -784,7 +786,7 @@ func createNSStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) {
|
||||
t.Helper()
|
||||
peer1 := &nbpeer.Peer{
|
||||
Key: nsGroupPeer1Key,
|
||||
|
||||
63
management/server/networks/manager.go
Normal file
63
management/server/networks/manager.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package networks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error)
|
||||
CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error)
|
||||
GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error)
|
||||
UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error)
|
||||
DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error
|
||||
GetResourceManager() resources.Manager
|
||||
GetRouterManager() routers.Manager
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
routersManager routers.Manager
|
||||
resourcesManager resources.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store.Store) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
routersManager: routers.NewManager(store),
|
||||
resourcesManager: resources.NewManager(store),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetResourceManager() resources.Manager {
|
||||
return m.resourcesManager
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRouterManager() routers.Manager {
|
||||
return m.routersManager
|
||||
}
|
||||
47
management/server/networks/resources/manager.go
Normal file
47
management/server/networks/resources/manager.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllResources(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error)
|
||||
CreateResource(ctx context.Context, accountID string, resource *types.NetworkResource) (*types.NetworkResource, error)
|
||||
GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error)
|
||||
UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error)
|
||||
DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func NewManager(store store.Store) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllResources(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateResource(ctx context.Context, accountID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package networks
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -8,14 +8,16 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
type NetworkResourceType string
|
||||
|
||||
const (
|
||||
host NetworkResourceType = "Host"
|
||||
subnet NetworkResourceType = "Subnet"
|
||||
domain NetworkResourceType = "Domain"
|
||||
host NetworkResourceType = "host"
|
||||
subnet NetworkResourceType = "subnet"
|
||||
domain NetworkResourceType = "domain"
|
||||
)
|
||||
|
||||
func (p NetworkResourceType) String() string {
|
||||
@@ -49,6 +51,25 @@ func NewNetworkResource(accountID, networkID, name, description, address string)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *NetworkResource) ToAPIResponse() *api.NetworkResource {
|
||||
return &api.NetworkResource{
|
||||
Id: n.ID,
|
||||
Name: n.Name,
|
||||
Description: &n.Description,
|
||||
Type: api.NetworkResourceType(n.Type.String()),
|
||||
Address: n.Address,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
|
||||
n.Name = req.Name
|
||||
n.Description = ""
|
||||
if req.Description != nil {
|
||||
n.Description = *req.Description
|
||||
}
|
||||
n.Address = req.Address
|
||||
}
|
||||
|
||||
func (n *NetworkResource) Copy() *NetworkResource {
|
||||
return &NetworkResource{
|
||||
ID: n.ID,
|
||||
@@ -1,4 +1,4 @@
|
||||
package networks
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
47
management/server/networks/routers/manager.go
Normal file
47
management/server/networks/routers/manager.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package routers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllRouters(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error)
|
||||
CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
|
||||
GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error)
|
||||
UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
|
||||
DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func NewManager(store store.Store) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllRouters(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package networks
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
type NetworkRouter struct {
|
||||
@@ -32,6 +34,23 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter {
|
||||
return &api.NetworkRouter{
|
||||
Id: n.ID,
|
||||
Peer: &n.Peer,
|
||||
PeerGroups: &n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
|
||||
n.Peer = *req.Peer
|
||||
n.PeerGroups = *req.PeerGroups
|
||||
n.Masquerade = req.Masquerade
|
||||
n.Metric = req.Metric
|
||||
}
|
||||
|
||||
func (n *NetworkRouter) Copy() *NetworkRouter {
|
||||
return &NetworkRouter{
|
||||
ID: n.ID,
|
||||
@@ -1,4 +1,4 @@
|
||||
package networks
|
||||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package networks
|
||||
package types
|
||||
|
||||
import "github.com/rs/xid"
|
||||
import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"index"`
|
||||
@@ -18,6 +22,19 @@ func NewNetwork(accountId, name, description string) *Network {
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) ToAPIResponse() *api.Network {
|
||||
return &api.Network{
|
||||
Id: n.ID,
|
||||
Name: n.Name,
|
||||
Description: &n.Description,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
|
||||
n.Name = req.Name
|
||||
n.Description = *req.Description
|
||||
}
|
||||
|
||||
// Copy returns a copy of a posture checks.
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -92,7 +94,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -107,7 +109,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *types.Account) error {
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find peer by pub key: %w", err)
|
||||
@@ -139,7 +141,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *types.Account) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
@@ -213,9 +215,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
if peerLabelUpdated {
|
||||
peer.Name = update.Name
|
||||
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
existingLabels := account.GetPeerDNSLabels()
|
||||
|
||||
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
|
||||
newLabel, err := types.GetPeerHostLabel(peer.Name, existingLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -278,7 +280,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
|
||||
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
|
||||
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *types.Account, peerIDs []string, userID string) error {
|
||||
|
||||
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
|
||||
// we might have some inconsistencies
|
||||
@@ -316,7 +318,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
FirewallRulesIsEmpty: true,
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
NetworkMap: &types.NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
@@ -358,7 +360,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) {
|
||||
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -383,7 +385,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) {
|
||||
func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -399,7 +401,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
|
||||
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
|
||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||
// The peer property is just a placeholder for the Peer properties to pass further
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if setupKey == "" && userID == "" {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
@@ -433,7 +435,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
@@ -446,12 +448,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
var newPeer *nbpeer.Peer
|
||||
var groupsToAdd []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
if addedByUser {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user groups: %w", err)
|
||||
}
|
||||
@@ -460,7 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
} else {
|
||||
// Validate the setup key
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey)
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
@@ -533,7 +535,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
@@ -558,7 +560,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -627,18 +629,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := s.GetTakenIPs(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||
}
|
||||
|
||||
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||
network, err := s.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
|
||||
nextIp, err := AllocatePeerIP(network.Net, takenIps)
|
||||
nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
||||
}
|
||||
@@ -647,7 +649,7 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
@@ -695,7 +697,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
|
||||
if peerNotValid {
|
||||
emptyMap := &NetworkMap{
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, []*posture.Checks{}, nil
|
||||
@@ -710,7 +712,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
|
||||
// Try registering it.
|
||||
@@ -730,7 +732,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return am.handlePeerLoginNotFound(ctx, login, err)
|
||||
@@ -755,12 +757,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}()
|
||||
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -785,7 +787,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -849,7 +851,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -860,7 +862,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -872,11 +874,11 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *types.Account, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
|
||||
if isRequiresApproval {
|
||||
emptyMap := &NetworkMap{
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, nil
|
||||
@@ -896,7 +898,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *types.User, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, user.Id, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -918,7 +920,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *types.User) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
if user.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||
@@ -939,7 +941,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool {
|
||||
func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) bool {
|
||||
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
expired = settings.PeerLoginExpirationEnabled && expired
|
||||
if expired || peer.Status.LoginExpired {
|
||||
@@ -991,7 +993,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peerID {
|
||||
return peer, nil
|
||||
@@ -1069,7 +1071,7 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
|
||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *types.Account, peerID string) (bool, error) {
|
||||
peerGroupIDs := make([]string, 0)
|
||||
for _, group := range account.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
|
||||
@@ -27,7 +27,9 @@ import (
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -37,13 +39,13 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
expirationEnabled bool
|
||||
lastLogin time.Time
|
||||
expected bool
|
||||
accountSettings *Settings
|
||||
accountSettings *types.Settings
|
||||
}{
|
||||
{
|
||||
name: "Peer Login Expiration Disabled. Peer Login Should Not Expire",
|
||||
expirationEnabled: false,
|
||||
lastLogin: time.Now().UTC().Add(-25 * time.Hour),
|
||||
accountSettings: &Settings{
|
||||
accountSettings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
@@ -53,7 +55,7 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
name: "Peer Login Should Expire",
|
||||
expirationEnabled: true,
|
||||
lastLogin: time.Now().UTC().Add(-25 * time.Hour),
|
||||
accountSettings: &Settings{
|
||||
accountSettings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
@@ -63,7 +65,7 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
name: "Peer Login Should Not Expire",
|
||||
expirationEnabled: true,
|
||||
lastLogin: time.Now().UTC(),
|
||||
accountSettings: &Settings{
|
||||
accountSettings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
@@ -92,14 +94,14 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
lastLogin time.Time
|
||||
connected bool
|
||||
expected bool
|
||||
accountSettings *Settings
|
||||
accountSettings *types.Settings
|
||||
}{
|
||||
{
|
||||
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
|
||||
expirationEnabled: false,
|
||||
connected: false,
|
||||
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||
accountSettings: &Settings{
|
||||
accountSettings: &types.Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
},
|
||||
@@ -110,7 +112,7 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
expirationEnabled: true,
|
||||
connected: false,
|
||||
lastLogin: time.Now().UTC().Add(-1 * time.Second),
|
||||
accountSettings: &Settings{
|
||||
accountSettings: &types.Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
@@ -121,7 +123,7 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
expirationEnabled: true,
|
||||
connected: true,
|
||||
lastLogin: time.Now().UTC(),
|
||||
accountSettings: &Settings{
|
||||
accountSettings: &types.Settings{
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Second,
|
||||
},
|
||||
@@ -161,7 +163,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -233,9 +235,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
var setupKey *types.SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Type == SetupKeyReusable {
|
||||
if key.Type == types.SetupKeyReusable {
|
||||
setupKey = key
|
||||
}
|
||||
}
|
||||
@@ -303,16 +305,16 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{group1.ID},
|
||||
Destinations: []string{group2.ID},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -410,7 +412,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -469,9 +471,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: UserRoleUser,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
account.Settings.RegularUsersViewBlocked = false
|
||||
|
||||
@@ -482,7 +484,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
// two peers one added by a regular user and one with a setup key
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -567,77 +569,77 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
role UserRole
|
||||
role types.UserRole
|
||||
limitedViewSettings bool
|
||||
isServiceUser bool
|
||||
expectedPeerCount int
|
||||
}{
|
||||
{
|
||||
name: "Regular user, no limited view settings, not a service user",
|
||||
role: UserRoleUser,
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: false,
|
||||
isServiceUser: false,
|
||||
expectedPeerCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Service user, no limited view settings",
|
||||
role: UserRoleUser,
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: false,
|
||||
isServiceUser: true,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Regular user, limited view settings",
|
||||
role: UserRoleUser,
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: true,
|
||||
isServiceUser: false,
|
||||
expectedPeerCount: 0,
|
||||
},
|
||||
{
|
||||
name: "Service user, limited view settings",
|
||||
role: UserRoleUser,
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: true,
|
||||
isServiceUser: true,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Admin, no limited view settings, not a service user",
|
||||
role: UserRoleAdmin,
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: false,
|
||||
isServiceUser: false,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Admin service user, no limited view settings",
|
||||
role: UserRoleAdmin,
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: false,
|
||||
isServiceUser: true,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Admin, limited view settings",
|
||||
role: UserRoleAdmin,
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: true,
|
||||
isServiceUser: false,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Admin Service user, limited view settings",
|
||||
role: UserRoleAdmin,
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: true,
|
||||
isServiceUser: true,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Owner, no limited view settings",
|
||||
role: UserRoleOwner,
|
||||
role: types.UserRoleOwner,
|
||||
limitedViewSettings: true,
|
||||
isServiceUser: false,
|
||||
expectedPeerCount: 2,
|
||||
},
|
||||
{
|
||||
name: "Owner, limited view settings",
|
||||
role: UserRoleOwner,
|
||||
role: types.UserRoleOwner,
|
||||
limitedViewSettings: true,
|
||||
isServiceUser: false,
|
||||
expectedPeerCount: 2,
|
||||
@@ -656,12 +658,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
IsServiceUser: testCase.isServiceUser,
|
||||
}
|
||||
account.Policies = []*Policy{}
|
||||
account.Policies = []*types.Policy{}
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
@@ -726,9 +728,9 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &User{
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: UserRoleUser,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
// Create peers
|
||||
@@ -746,7 +748,7 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
account.Policies = make([]*Policy, 0, groups)
|
||||
account.Policies = make([]*types.Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &nbgroup.Group{
|
||||
@@ -760,11 +762,11 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
account.Groups[groupID] = group
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: fmt.Sprintf("rule-%d", i),
|
||||
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||
@@ -772,8 +774,8 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
Sources: []string{groupID},
|
||||
Destinations: []string{groupID},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -939,8 +941,8 @@ func TestToSyncResponse(t *testing.T) {
|
||||
Payload: "turn-user",
|
||||
Signature: "turn-pass",
|
||||
}
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipnet, Serial: 1000},
|
||||
networkMap := &types.NetworkMap{
|
||||
Network: &types.Network{Net: *ipnet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
@@ -987,8 +989,8 @@ func TestToSyncResponse(t *testing.T) {
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
FirewallRules: []*types.FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: types.FirewallRuleDirectionIN, Action: string(types.PolicyTrafficActionAccept), Protocol: string(types.PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
@@ -1088,7 +1090,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1099,13 +1101,13 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
_, err = s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
@@ -1128,12 +1130,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
assert.Equal(t, peer.UserID, existingUserID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
|
||||
@@ -1152,7 +1154,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1163,13 +1165,13 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
_, err = s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
@@ -1192,11 +1194,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peer.AccountID, existingAccountID)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, account.Peers, addedPeer.ID)
|
||||
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
|
||||
@@ -1219,7 +1221,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1230,13 +1232,13 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
_, err = s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
@@ -1258,10 +1260,10 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
|
||||
require.Error(t, err)
|
||||
|
||||
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
|
||||
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
|
||||
require.Error(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
account, err := s.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, account.Peers, newPeer.ID)
|
||||
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
|
||||
@@ -1304,26 +1306,26 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a user with auto groups
|
||||
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{
|
||||
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{
|
||||
{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
AutoGroups: []string{"groupA"},
|
||||
},
|
||||
{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
AutoGroups: []string{"groupB"},
|
||||
},
|
||||
{
|
||||
Id: "regularUser3",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
AutoGroups: []string{"groupC"},
|
||||
},
|
||||
}, true)
|
||||
@@ -1464,15 +1466,15 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -3,344 +3,22 @@ package server
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// PolicyUpdateOperationType operation type
|
||||
type PolicyUpdateOperationType int
|
||||
|
||||
// PolicyTrafficActionType action type for the firewall
|
||||
type PolicyTrafficActionType string
|
||||
|
||||
// PolicyRuleProtocolType type of traffic
|
||||
type PolicyRuleProtocolType string
|
||||
|
||||
// PolicyRuleDirection direction of traffic
|
||||
type PolicyRuleDirection string
|
||||
|
||||
const (
|
||||
// PolicyTrafficActionAccept indicates that the traffic is accepted
|
||||
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
|
||||
// PolicyTrafficActionDrop indicates that the traffic is dropped
|
||||
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
|
||||
)
|
||||
|
||||
const (
|
||||
// PolicyRuleProtocolALL type of traffic
|
||||
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
|
||||
// PolicyRuleProtocolTCP type of traffic
|
||||
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
|
||||
// PolicyRuleProtocolUDP type of traffic
|
||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||
// PolicyRuleProtocolICMP type of traffic
|
||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||
)
|
||||
|
||||
const (
|
||||
// PolicyRuleFlowDirect allows traffic from source to destination
|
||||
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
|
||||
// PolicyRuleFlowBidirect allows traffic to both directions
|
||||
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultRuleName is a name for the Default rule that is created for every account
|
||||
DefaultRuleName = "Default"
|
||||
// DefaultRuleDescription is a description for the Default rule that is created for every account
|
||||
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
|
||||
// DefaultPolicyName is a name for the Default policy that is created for every account
|
||||
DefaultPolicyName = "Default"
|
||||
// DefaultPolicyDescription is a description for the Default policy that is created for every account
|
||||
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
|
||||
)
|
||||
|
||||
const (
|
||||
firewallRuleDirectionIN = 0
|
||||
firewallRuleDirectionOUT = 1
|
||||
)
|
||||
|
||||
// PolicyUpdateOperation operation object with type and values to be applied
|
||||
type PolicyUpdateOperation struct {
|
||||
Type PolicyUpdateOperationType
|
||||
Values []string
|
||||
}
|
||||
|
||||
// RulePortRange represents a range of ports for a firewall rule.
|
||||
type RulePortRange struct {
|
||||
Start uint16
|
||||
End uint16
|
||||
}
|
||||
|
||||
// PolicyRule is the metadata of the policy
|
||||
type PolicyRule struct {
|
||||
// ID of the policy rule
|
||||
ID string `gorm:"primaryKey"`
|
||||
|
||||
// PolicyID is a reference to Policy that this object belongs
|
||||
PolicyID string `json:"-" gorm:"index"`
|
||||
|
||||
// Name of the rule visible in the UI
|
||||
Name string
|
||||
|
||||
// Description of the rule visible in the UI
|
||||
Description string
|
||||
|
||||
// Enabled status of rule in the system
|
||||
Enabled bool
|
||||
|
||||
// Action policy accept or drops packets
|
||||
Action PolicyTrafficActionType
|
||||
|
||||
// Destinations policy destination groups
|
||||
Destinations []string `gorm:"serializer:json"`
|
||||
|
||||
// Sources policy source groups
|
||||
Sources []string `gorm:"serializer:json"`
|
||||
|
||||
// Bidirectional define if the rule is applicable in both directions, sources, and destinations
|
||||
Bidirectional bool
|
||||
|
||||
// Protocol type of the traffic
|
||||
Protocol PolicyRuleProtocolType
|
||||
|
||||
// Ports or it ranges list
|
||||
Ports []string `gorm:"serializer:json"`
|
||||
|
||||
// PortRanges a list of port ranges.
|
||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
rule := &PolicyRule{
|
||||
ID: pm.ID,
|
||||
PolicyID: pm.PolicyID,
|
||||
Name: pm.Name,
|
||||
Description: pm.Description,
|
||||
Enabled: pm.Enabled,
|
||||
Action: pm.Action,
|
||||
Destinations: make([]string, len(pm.Destinations)),
|
||||
Sources: make([]string, len(pm.Sources)),
|
||||
Bidirectional: pm.Bidirectional,
|
||||
Protocol: pm.Protocol,
|
||||
Ports: make([]string, len(pm.Ports)),
|
||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
copy(rule.Ports, pm.Ports)
|
||||
copy(rule.PortRanges, pm.PortRanges)
|
||||
return rule
|
||||
}
|
||||
|
||||
// Policy of the Rego query
|
||||
type Policy struct {
|
||||
// ID of the policy'
|
||||
ID string `gorm:"primaryKey"`
|
||||
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// Name of the Policy
|
||||
Name string
|
||||
|
||||
// Description of the policy visible in the UI
|
||||
Description string
|
||||
|
||||
// Enabled status of the policy
|
||||
Enabled bool
|
||||
|
||||
// Rules of the policy
|
||||
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// SourcePostureChecks are ID references to Posture checks for policy source groups
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
Rules: make([]*PolicyRule, len(p.Rules)),
|
||||
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
|
||||
}
|
||||
for i, r := range p.Rules {
|
||||
c.Rules[i] = r.Copy()
|
||||
}
|
||||
copy(c.SourcePostureChecks, p.SourcePostureChecks)
|
||||
return c
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to this policy
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{"name": p.Name}
|
||||
}
|
||||
|
||||
// UpgradeAndFix different version of policies to latest version
|
||||
func (p *Policy) UpgradeAndFix() {
|
||||
for _, r := range p.Rules {
|
||||
// start migrate from version v0.20.3
|
||||
if r.Protocol == "" {
|
||||
r.Protocol = PolicyRuleProtocolALL
|
||||
}
|
||||
if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional {
|
||||
r.Bidirectional = true
|
||||
}
|
||||
// -- v0.20.4
|
||||
}
|
||||
}
|
||||
|
||||
// ruleGroups returns a list of all groups referenced in the policy's rules,
|
||||
// including sources and destinations.
|
||||
func (p *Policy) ruleGroups() []string {
|
||||
groups := make([]string, 0)
|
||||
for _, rule := range p.Rules {
|
||||
groups = append(groups, rule.Sources...)
|
||||
groups = append(groups, rule.Destinations...)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// FirewallRule is a rule of the firewall.
|
||||
type FirewallRule struct {
|
||||
// PeerIP of the peer
|
||||
PeerIP string
|
||||
|
||||
// Direction of the traffic
|
||||
Direction int
|
||||
|
||||
// Action of the traffic
|
||||
Action string
|
||||
|
||||
// Protocol of the traffic
|
||||
Protocol string
|
||||
|
||||
// Port of the traffic
|
||||
Port string
|
||||
}
|
||||
|
||||
// getPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, firewallRuleDirectionIN)
|
||||
}
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, firewallRuleDirectionOUT)
|
||||
}
|
||||
}
|
||||
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, firewallRuleDirectionOUT)
|
||||
}
|
||||
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, firewallRuleDirectionIN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
//
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &nbgroup.Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := peersExists[peer.ID]; !ok {
|
||||
peers = append(peers, peer)
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
if len(rule.Ports) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
pr := fr // clone rule and add set new port
|
||||
pr.Port = port
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
}
|
||||
}
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -353,15 +31,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -378,7 +56,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -388,7 +66,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -398,7 +76,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
|
||||
return saveFunc(ctx, LockingStrengthUpdate, policy)
|
||||
return saveFunc(ctx, store.LockingStrengthUpdate, policy)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -418,7 +96,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -431,11 +109,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var policy *Policy
|
||||
var policy *types.Policy
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -445,11 +123,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -465,8 +143,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
}
|
||||
|
||||
// ListPolicies from the store.
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -479,13 +157,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -494,7 +172,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -504,13 +182,13 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
if policy.ID != "" {
|
||||
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -519,12 +197,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
|
||||
policy.AccountID = accountID
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -548,84 +226,6 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
// and a boolean indicating if the supplied peer ID exists within these groups.
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
group, ok := a.Groups[g]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range group.Peers {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
}
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok && peer == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := a.getPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, check := range postureChecks.GetChecks() {
|
||||
isValid, err := check.Check(ctx, *peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
}
|
||||
if !isValid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range a.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
return postureChecks
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
||||
validIDs := make([]string, 0, len(postureChecksIds))
|
||||
@@ -651,7 +251,7 @@ func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []str
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
@@ -13,10 +13,11 @@ import (
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peerA": {
|
||||
ID: "peerA",
|
||||
@@ -87,21 +88,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "RuleDefault",
|
||||
Name: "Default",
|
||||
Description: "This is a default rule that allows connections between all the resources",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleDefault",
|
||||
Name: "Default",
|
||||
Description: "This is a default rule that allows connections between all the resources",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"GroupAll",
|
||||
},
|
||||
@@ -116,15 +117,15 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
Name: "Swarm",
|
||||
Description: "No description",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleSwarm",
|
||||
Name: "Swarm",
|
||||
Description: "No description",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"GroupSwarm",
|
||||
"GroupAll",
|
||||
@@ -145,14 +146,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
||||
assert.Len(t, peers, 7)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@@ -160,45 +161,45 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
assert.Contains(t, peers, account.Peers["peerE"])
|
||||
assert.Contains(t, peers, account.Peers["peerF"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -206,14 +207,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -221,14 +222,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -236,14 +237,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
{
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -251,14 +252,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -266,14 +267,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -289,7 +290,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peerA": {
|
||||
ID: "peerA",
|
||||
@@ -332,21 +333,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "RuleDefault",
|
||||
Name: "Default",
|
||||
Description: "This is a default rule that allows connections between all the resources",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleDefault",
|
||||
Name: "Default",
|
||||
Description: "This is a default rule that allows connections between all the resources",
|
||||
Bidirectional: true,
|
||||
Enabled: false,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"GroupAll",
|
||||
},
|
||||
@@ -361,15 +362,15 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
Name: "Swarm",
|
||||
Description: "No description",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleSwarm",
|
||||
Name: "Swarm",
|
||||
Description: "No description",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"GroupSwarm",
|
||||
},
|
||||
@@ -388,20 +389,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -416,20 +417,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -446,13 +447,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -467,13 +468,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
@@ -489,7 +490,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peerA": {
|
||||
ID: "peerA",
|
||||
@@ -630,17 +631,17 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
account.Policies = append(account.Policies, &Policy{
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "PolicyPostureChecks",
|
||||
Name: "",
|
||||
Description: "This is the policy with posture checks applied",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleSwarm",
|
||||
Name: "Swarm",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Destinations: []string{
|
||||
"GroupSwarm",
|
||||
},
|
||||
@@ -648,7 +649,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
"GroupAll",
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
@@ -664,7 +665,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -674,13 +675,13 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
expectedFirewallRules := []*FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
@@ -690,7 +691,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -700,7 +701,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -715,19 +716,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -742,14 +743,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
@@ -760,45 +761,45 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
// assert peers from Group All
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
expectedFirewallRules := []*FirewallRule{
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
@@ -809,8 +810,8 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func sortFunc() func(a *FirewallRule, b *FirewallRule) int {
|
||||
return func(a, b *FirewallRule) int {
|
||||
func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
|
||||
return func(a, b *types.FirewallRule) int {
|
||||
// Concatenate PeerIP and Direction as string for comparison
|
||||
aStr := a.PeerIP + fmt.Sprintf("%d", a.Direction)
|
||||
bStr := b.PeerIP + fmt.Sprintf("%d", b.Direction)
|
||||
@@ -858,9 +859,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var policyWithGroupRulesNoPeers *Policy
|
||||
var policyWithDestinationPeersOnly *Policy
|
||||
var policyWithSourceAndDestinationPeers *Policy
|
||||
var policyWithGroupRulesNoPeers *types.Policy
|
||||
var policyWithDestinationPeersOnly *types.Policy
|
||||
var policyWithSourceAndDestinationPeers *types.Policy
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
@@ -870,16 +871,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -901,17 +902,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -933,17 +934,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupC"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -965,16 +966,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -12,10 +12,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -28,7 +30,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
@@ -36,7 +38,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,7 +55,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
var isUpdate = postureChecks.ID != ""
|
||||
var action = activity.PostureCheckCreated
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -64,7 +66,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -72,7 +74,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
|
||||
return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -92,7 +94,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -107,8 +109,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
|
||||
var postureChecks *posture.Checks
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,11 +119,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
|
||||
return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -134,7 +136,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
|
||||
// ListPostureChecks returns a list of posture checks.
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -147,11 +149,11 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
@@ -172,15 +174,15 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID s
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -195,21 +197,21 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, a
|
||||
}
|
||||
|
||||
// validatePostureChecks validates the posture checks.
|
||||
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
|
||||
func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error {
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
// If the posture check already has an ID, verify its existence in the store.
|
||||
if postureChecks.ID != "" {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For new posture checks, ensure no duplicates by name.
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -226,7 +228,7 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -237,7 +239,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck := account.getPostureChecks(sourcePostureCheckID)
|
||||
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
|
||||
if postureCheck == nil {
|
||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||
}
|
||||
@@ -248,7 +250,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
|
||||
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
@@ -270,8 +272,8 @@ func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy)
|
||||
}
|
||||
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
)
|
||||
@@ -92,17 +94,17 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, error) {
|
||||
accountID := "testingAccount"
|
||||
domain := "example.com"
|
||||
|
||||
admin := &User{
|
||||
admin := &types.User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
}
|
||||
user := &User{
|
||||
user := &types.User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleUser,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
@@ -209,15 +211,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
@@ -312,15 +314,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
@@ -356,15 +358,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
@@ -395,15 +397,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
@@ -454,7 +456,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
AccountID: account.Id,
|
||||
Peers: []string{},
|
||||
}
|
||||
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
@@ -477,9 +479,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
|
||||
require.NoError(t, err, "failed to save postureCheckB")
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
@@ -534,7 +536,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
groupA.Peers = []string{}
|
||||
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
|
||||
err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
|
||||
@@ -4,15 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -21,33 +18,9 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// RouteFirewallRule a firewall rule applicable for a routed network.
|
||||
type RouteFirewallRule struct {
|
||||
// SourceRanges IP ranges of the routing peers.
|
||||
SourceRanges []string
|
||||
|
||||
// Action of the traffic when the rule is applicable
|
||||
Action string
|
||||
|
||||
// Destination a network prefix for the routed traffic
|
||||
Destination string
|
||||
|
||||
// Protocol of the traffic
|
||||
Protocol string
|
||||
|
||||
// Port of the traffic
|
||||
Port uint16
|
||||
|
||||
// PortRange represents the range of ports for a firewall rule
|
||||
PortRange RulePortRange
|
||||
|
||||
// isDynamic indicates whether the rule is for DNS routing
|
||||
IsDynamic bool
|
||||
}
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -56,11 +29,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
|
||||
@@ -364,7 +337,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -373,7 +346,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
@@ -404,244 +377,7 @@ func getPlaceholderIP() netip.Prefix {
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||
}
|
||||
|
||||
// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
if pID == peerID {
|
||||
continue
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||
for pID := range distPeersWithPolicy {
|
||||
peer := a.Peers[pID]
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
distributionGroupPeers = append(distributionGroupPeers, peer)
|
||||
}
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
IsDynamic: route.IsDynamic(),
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
routePolicies := make([]*Policy, 0)
|
||||
for _, groupID := range accessControlGroups {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
|
||||
return groupID == group.ID
|
||||
})
|
||||
if exist {
|
||||
routePolicies = append(routePolicies, policy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routePolicies
|
||||
}
|
||||
|
||||
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||
func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
|
||||
rulesExists := make(map[string]struct{})
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
sourceRanges := make([]string, 0, len(groupPeers))
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
|
||||
}
|
||||
|
||||
baseRule := RouteFirewallRule{
|
||||
SourceRanges: sourceRanges,
|
||||
Action: string(rule.Action),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(rule.Protocol),
|
||||
IsDynamic: route.IsDynamic(),
|
||||
}
|
||||
|
||||
// generate rule for port range
|
||||
if len(rule.Ports) == 0 {
|
||||
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
|
||||
} else {
|
||||
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
|
||||
|
||||
}
|
||||
|
||||
// TODO: generate IPv6 rules for dynamic routes
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// generateRuleIDBase generates the base rule ID for checking duplicates.
|
||||
func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
|
||||
return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
|
||||
}
|
||||
|
||||
// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
|
||||
func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
ruleIDBase := generateRuleIDBase(rule, baseRule)
|
||||
if len(rule.Ports) == 0 {
|
||||
if len(rule.PortRanges) == 0 {
|
||||
if _, ok := rulesExists[ruleIDBase]; !ok {
|
||||
rulesExists[ruleIDBase] = struct{}{}
|
||||
rules = append(rules, &baseRule)
|
||||
}
|
||||
} else {
|
||||
for _, portRange := range rule.PortRanges {
|
||||
ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
|
||||
if _, ok := rulesExists[ruleID]; !ok {
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
pr := baseRule
|
||||
pr.PortRange = portRange
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// generateRulesWithPorts generates rules when specific ports are provided.
|
||||
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
ruleIDBase := generateRuleIDBase(rule, baseRule)
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
ruleID := ruleIDBase + port
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
pr := baseRule
|
||||
p, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
pr.Port = uint16(p)
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
@@ -660,7 +396,7 @@ func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFir
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == firewallRuleDirectionOUT {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
@@ -668,7 +404,7 @@ func getProtoDirection(direction int) proto.RuleDirection {
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(PolicyTrafficActionDrop) {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
@@ -676,14 +412,14 @@ func getProtoAction(action string) proto.RuleAction {
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch PolicyRuleProtocolType(protocol) {
|
||||
case PolicyRuleProtocolALL:
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case PolicyRuleProtocolTCP:
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case PolicyRuleProtocolUDP:
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case PolicyRuleProtocolICMP:
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
@@ -691,7 +427,7 @@ func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
@@ -708,6 +444,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool {
|
||||
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
}
|
||||
|
||||
@@ -17,7 +17,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -1092,7 +1094,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id)
|
||||
require.NoError(t, err)
|
||||
var groupHA1, groupHA2 *nbgroup.Group
|
||||
for _, group := range groups {
|
||||
@@ -1255,10 +1257,10 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1267,7 +1269,7 @@ func createRouterStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) {
|
||||
t.Helper()
|
||||
|
||||
accountID := "testingAcc"
|
||||
@@ -1279,8 +1281,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ips := account.getTakenIPs()
|
||||
peer1IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
ips := account.GetTakenIPs()
|
||||
peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1306,8 +1308,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
}
|
||||
account.Peers[peer1.ID] = peer1
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
ips = account.GetTakenIPs()
|
||||
peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1333,8 +1335,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
}
|
||||
account.Peers[peer2.ID] = peer2
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
ips = account.GetTakenIPs()
|
||||
peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1360,8 +1362,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
}
|
||||
account.Peers[peer3.ID] = peer3
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
ips = account.GetTakenIPs()
|
||||
peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1387,8 +1389,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
}
|
||||
account.Peers[peer4.ID] = peer4
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
ips = account.GetTakenIPs()
|
||||
peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1491,7 +1493,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
peerKIp = "100.65.29.66"
|
||||
)
|
||||
|
||||
account := &Account{
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peerA": {
|
||||
ID: "peerA",
|
||||
@@ -1685,19 +1687,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
AccessControlGroups: []string{"route4"},
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "RuleRoute1",
|
||||
Name: "Route1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleRoute1",
|
||||
Name: "ruleRoute1",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Ports: []string{"80", "320"},
|
||||
Sources: []string{
|
||||
"dev",
|
||||
@@ -1712,15 +1714,15 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
ID: "RuleRoute2",
|
||||
Name: "Route2",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleRoute2",
|
||||
Name: "ruleRoute2",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
PortRanges: []RulePortRange{
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
PortRanges: []types.RulePortRange{
|
||||
{
|
||||
Start: 80,
|
||||
End: 350,
|
||||
@@ -1742,14 +1744,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
ID: "RuleRoute4",
|
||||
Name: "RuleRoute4",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleRoute4",
|
||||
Name: "RuleRoute4",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Ports: []string{"80"},
|
||||
Sources: []string{
|
||||
"restrictQA",
|
||||
@@ -1764,14 +1766,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
ID: "RuleRoute5",
|
||||
Name: "RuleRoute5",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: "RuleRoute5",
|
||||
Name: "RuleRoute5",
|
||||
Bidirectional: true,
|
||||
Enabled: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"unrestrictedQA",
|
||||
},
|
||||
@@ -1791,28 +1793,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
|
||||
t.Run("check applied policies for the route", func(t *testing.T) {
|
||||
route1 := account.Routes["route1"]
|
||||
policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
||||
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
||||
assert.Len(t, policies, 1)
|
||||
|
||||
route2 := account.Routes["route2"]
|
||||
policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
|
||||
policies = types.GetAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
|
||||
assert.Len(t, policies, 1)
|
||||
|
||||
route3 := account.Routes["route3"]
|
||||
policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
||||
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
||||
assert.Len(t, policies, 0)
|
||||
})
|
||||
|
||||
t.Run("check peer routes firewall rules", func(t *testing.T) {
|
||||
routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
||||
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 4)
|
||||
|
||||
expectedRoutesFirewallRules := []*RouteFirewallRule{
|
||||
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(AllowedIPsFormat, peerBIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
@@ -1821,9 +1823,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(AllowedIPsFormat, peerBIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
@@ -1831,10 +1833,10 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
Port: 320,
|
||||
},
|
||||
}
|
||||
additionalFirewallRule := []*RouteFirewallRule{
|
||||
additionalFirewallRule := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(AllowedIPsFormat, peerJIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
@@ -1843,7 +1845,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(AllowedIPsFormat, peerKIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
@@ -1854,21 +1856,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
|
||||
|
||||
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
|
||||
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 2)
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerE is a single routing peer for route 2 and route 3
|
||||
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 3)
|
||||
|
||||
expectedRoutesFirewallRules = []*RouteFirewallRule{
|
||||
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
||||
Action: "accept",
|
||||
Destination: existingNetwork.String(),
|
||||
Protocol: "tcp",
|
||||
PortRange: RulePortRange{Start: 80, End: 350},
|
||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{"0.0.0.0/0"},
|
||||
@@ -1888,14 +1890,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
|
||||
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 0)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// orderList is a helper function to sort a list of strings
|
||||
func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule {
|
||||
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
|
||||
for _, rule := range ruleList {
|
||||
sort.Strings(rule.SourceRanges)
|
||||
}
|
||||
|
||||
@@ -2,34 +2,16 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"hash/fnv"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// SetupKeyReusable is a multi-use key (can be used for multiple machines)
|
||||
SetupKeyReusable SetupKeyType = "reusable"
|
||||
// SetupKeyOneOff is a single use key (can be used only once)
|
||||
SetupKeyOneOff SetupKeyType = "one-off"
|
||||
|
||||
// DefaultSetupKeyDuration = 1 month
|
||||
DefaultSetupKeyDuration = 24 * 30 * time.Hour
|
||||
// DefaultSetupKeyName is a default name of the default setup key
|
||||
DefaultSetupKeyName = "Default key"
|
||||
// SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key
|
||||
SetupKeyUnlimitedUsage = 0
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -67,169 +49,14 @@ type SetupKeyUpdateOperation struct {
|
||||
Values []string
|
||||
}
|
||||
|
||||
// SetupKeyType is the type of setup key
|
||||
type SetupKeyType string
|
||||
|
||||
// SetupKey represents a pre-authorized key used to register machines (peers)
|
||||
type SetupKey struct {
|
||||
Id string
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Key string
|
||||
KeySecret string
|
||||
Name string
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
|
||||
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
|
||||
Revoked bool
|
||||
// UsedTimes indicates how many times the key was used
|
||||
UsedTimes int
|
||||
// LastUsed last time the key was used for peer registration
|
||||
LastUsed time.Time
|
||||
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
// UsageLimit indicates the number of times this key can be used to enroll a machine.
|
||||
// The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int
|
||||
// Ephemeral indicate if the peers will be ephemeral or not
|
||||
Ephemeral bool
|
||||
}
|
||||
|
||||
// Copy copies SetupKey to a new object
|
||||
func (key *SetupKey) Copy() *SetupKey {
|
||||
autoGroups := make([]string, len(key.AutoGroups))
|
||||
copy(autoGroups, key.AutoGroups)
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = key.CreatedAt
|
||||
}
|
||||
return &SetupKey{
|
||||
Id: key.Id,
|
||||
AccountID: key.AccountID,
|
||||
Key: key.Key,
|
||||
KeySecret: key.KeySecret,
|
||||
Name: key.Name,
|
||||
Type: key.Type,
|
||||
CreatedAt: key.CreatedAt,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
UpdatedAt: key.UpdatedAt,
|
||||
Revoked: key.Revoked,
|
||||
UsedTimes: key.UsedTimes,
|
||||
LastUsed: key.LastUsed,
|
||||
AutoGroups: autoGroups,
|
||||
UsageLimit: key.UsageLimit,
|
||||
Ephemeral: key.Ephemeral,
|
||||
}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the setup key
|
||||
func (key *SetupKey) EventMeta() map[string]any {
|
||||
return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
|
||||
}
|
||||
|
||||
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
|
||||
// E.g., "831F6*******************************"
|
||||
func hiddenKey(key string, length int) string {
|
||||
prefix := key[0:5]
|
||||
if length > utf8.RuneCountInString(key) {
|
||||
length = utf8.RuneCountInString(key) - len(prefix)
|
||||
}
|
||||
return prefix + strings.Repeat("*", length)
|
||||
}
|
||||
|
||||
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
||||
func (key *SetupKey) IncrementUsage() *SetupKey {
|
||||
c := key.Copy()
|
||||
c.UsedTimes++
|
||||
c.LastUsed = time.Now().UTC()
|
||||
return c
|
||||
}
|
||||
|
||||
// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to
|
||||
func (key *SetupKey) IsValid() bool {
|
||||
return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed()
|
||||
}
|
||||
|
||||
// IsRevoked if key was revoked
|
||||
func (key *SetupKey) IsRevoked() bool {
|
||||
return key.Revoked
|
||||
}
|
||||
|
||||
// IsExpired if key was expired
|
||||
func (key *SetupKey) IsExpired() bool {
|
||||
if key.ExpiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(key.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage.
|
||||
func (key *SetupKey) IsOverUsed() bool {
|
||||
limit := key.UsageLimit
|
||||
if key.Type == SetupKeyOneOff {
|
||||
limit = 1
|
||||
}
|
||||
return limit > 0 && key.UsedTimes >= limit
|
||||
}
|
||||
|
||||
// GenerateSetupKey generates a new setup key
|
||||
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
|
||||
usageLimit int, ephemeral bool) (*SetupKey, string) {
|
||||
key := strings.ToUpper(uuid.New().String())
|
||||
limit := usageLimit
|
||||
if t == SetupKeyOneOff {
|
||||
limit = 1
|
||||
}
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if validFor != 0 {
|
||||
expiresAt = time.Now().UTC().Add(validFor)
|
||||
}
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(key))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
return &SetupKey{
|
||||
Id: strconv.Itoa(int(Hash(key))),
|
||||
Key: encodedHashedKey,
|
||||
KeySecret: hiddenKey(key, 4),
|
||||
Name: name,
|
||||
Type: t,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: expiresAt,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
Revoked: false,
|
||||
UsedTimes: 0,
|
||||
AutoGroups: autoGroups,
|
||||
UsageLimit: limit,
|
||||
Ephemeral: ephemeral,
|
||||
}, key
|
||||
}
|
||||
|
||||
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
|
||||
func GenerateDefaultSetupKey() (*SetupKey, string) {
|
||||
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
|
||||
SetupKeyUnlimitedUsage, false)
|
||||
}
|
||||
|
||||
func Hash(s string) uint32 {
|
||||
h := fnv.New32a()
|
||||
_, err := h.Write([]byte(s))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
// CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key,
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -242,22 +69,22 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
var setupKey *types.SetupKey
|
||||
var plainKey string
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
setupKey.AccountID = accountID
|
||||
|
||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
|
||||
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -278,7 +105,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *types.SetupKey, userID string) (*types.SetupKey, error) {
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
@@ -286,7 +113,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -299,16 +126,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
var newKey *SetupKey
|
||||
var oldKey *types.SetupKey
|
||||
var newKey *types.SetupKey
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
|
||||
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -323,13 +150,13 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
|
||||
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
addedGroups := util.Difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := util.Difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
|
||||
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
|
||||
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -347,8 +174,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
}
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -361,12 +188,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -379,7 +206,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -394,7 +221,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
|
||||
// DeleteSetupKey removes the setup key from the account
|
||||
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -407,15 +234,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var deletedSetupKey *SetupKey
|
||||
var deletedSetupKey *types.SetupKey
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
|
||||
return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -426,8 +253,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
|
||||
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -447,11 +274,11 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI
|
||||
}
|
||||
|
||||
// prepareSetupKeyEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
|
||||
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string, key *types.SetupKey) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||
return nil
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
@@ -49,15 +50,15 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
expiresIn := time.Hour
|
||||
keyName := "my-test-key"
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{},
|
||||
SetupKeyUnlimitedUsage, userID, false)
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, types.SetupKeyReusable, expiresIn, []string{},
|
||||
types.SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{
|
||||
Id: key.Id,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
@@ -85,7 +86,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
|
||||
// saving setup key with All group assigned to auto groups should return error
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{
|
||||
Id: key.Id,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
@@ -167,8 +168,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, types.SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups, types.SetupKeyUnlimitedUsage, userID, false)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
if err == nil {
|
||||
@@ -182,7 +183,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
|
||||
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
|
||||
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))),
|
||||
tCase.expectedUpdatedAt, tCase.expectedGroups, false)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
@@ -210,7 +211,7 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -258,10 +259,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
|
||||
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key, plainKey := GenerateDefaultSetupKey()
|
||||
key, plainKey := types.GenerateDefaultSetupKey()
|
||||
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true)
|
||||
expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true)
|
||||
|
||||
}
|
||||
|
||||
@@ -275,48 +276,48 @@ func TestGenerateSetupKey(t *testing.T) {
|
||||
expectedUpdatedAt := time.Now().UTC()
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
|
||||
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
|
||||
expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true)
|
||||
expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true)
|
||||
|
||||
}
|
||||
|
||||
func TestSetupKey_IsValid(t *testing.T) {
|
||||
validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
validKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
if !validKey.IsValid() {
|
||||
t.Errorf("expected key to be valid, got invalid %v", validKey)
|
||||
}
|
||||
|
||||
// expired
|
||||
expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
expiredKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, -time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
if expiredKey.IsValid() {
|
||||
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
|
||||
}
|
||||
|
||||
// revoked
|
||||
revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
revokedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
revokedKey.Revoked = true
|
||||
if revokedKey.IsValid() {
|
||||
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
|
||||
}
|
||||
|
||||
// overused
|
||||
overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
overUsedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
overUsedKey.UsedTimes = 1
|
||||
if overUsedKey.IsValid() {
|
||||
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
|
||||
}
|
||||
|
||||
// overused
|
||||
reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
reusableKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyReusable, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
reusableKey.UsedTimes = 99
|
||||
if !reusableKey.IsValid() {
|
||||
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
|
||||
}
|
||||
}
|
||||
|
||||
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
|
||||
func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedRevoke bool, expectedType string,
|
||||
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
|
||||
expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) {
|
||||
t.Helper()
|
||||
@@ -388,7 +389,7 @@ func isValidBase64SHA256(encodedKey string) bool {
|
||||
|
||||
func TestSetupKey_Copy(t *testing.T) {
|
||||
|
||||
key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
|
||||
key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
|
||||
keyCopy := key.Copy()
|
||||
|
||||
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
|
||||
@@ -406,15 +407,15 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"group"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -426,7 +427,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var setupKey *SetupKey
|
||||
var setupKey *types.SetupKey
|
||||
|
||||
// Creating setup key should not update account peers and not send peer update
|
||||
t.Run("creating setup key", func(t *testing.T) {
|
||||
@@ -436,7 +437,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -477,7 +478,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// revoke the key
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ const storeFileName = "store.json"
|
||||
|
||||
// FileStore represents an account storage backed by a file persisted to disk
|
||||
type FileStore struct {
|
||||
Accounts map[string]*Account
|
||||
Accounts map[string]*types.Account
|
||||
SetupKeyID2AccountID map[string]string `json:"-"`
|
||||
PeerKeyID2AccountID map[string]string `json:"-"`
|
||||
PeerID2AccountID map[string]string `json:"-"`
|
||||
@@ -55,7 +56,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
if _, err := os.Stat(file); os.IsNotExist(err) {
|
||||
// create a new FileStore if previously didn't exist (e.g. first run)
|
||||
s := &FileStore{
|
||||
Accounts: make(map[string]*Account),
|
||||
Accounts: make(map[string]*types.Account),
|
||||
mux: sync.Mutex{},
|
||||
SetupKeyID2AccountID: make(map[string]string),
|
||||
PeerKeyID2AccountID: make(map[string]string),
|
||||
@@ -92,12 +93,12 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
|
||||
for accountID, account := range store.Accounts {
|
||||
if account.Settings == nil {
|
||||
account.Settings = &Settings{
|
||||
account.Settings = &types.Settings{
|
||||
PeerLoginExpirationEnabled: false,
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +113,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
for _, user := range account.Users {
|
||||
store.UserID2AccountID[user.Id] = accountID
|
||||
if user.Issued == "" {
|
||||
user.Issued = UserIssuedAPI
|
||||
user.Issued = types.UserIssuedAPI
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
@@ -122,7 +123,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
|
||||
if account.Domain != "" && account.DomainCategory == types.PrivateCategory &&
|
||||
account.IsDomainPrimaryAccount {
|
||||
store.PrivateDomain2AccountID[account.Domain] = accountID
|
||||
}
|
||||
@@ -134,13 +135,13 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
policy.UpgradeAndFix()
|
||||
}
|
||||
if account.Policies == nil {
|
||||
account.Policies = make([]*Policy, 0)
|
||||
account.Policies = make([]*types.Policy, 0)
|
||||
}
|
||||
|
||||
// for data migration. Can be removed once most base will be with labels
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
existingLabels := account.GetPeerDNSLabels()
|
||||
if len(existingLabels) != len(account.Peers) {
|
||||
addPeerLabelsToAccount(ctx, account, existingLabels)
|
||||
types.AddPeerLabelsToAccount(ctx, account, existingLabels)
|
||||
}
|
||||
|
||||
// TODO: delete this block after migration
|
||||
@@ -236,7 +237,7 @@ func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
}
|
||||
|
||||
// GetAllAccounts returns all accounts
|
||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*types.Account) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
for _, a := range s.Accounts {
|
||||
@@ -257,6 +258,6 @@ func (s *FileStore) Close(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// GetStoreEngine returns FileStoreEngine
|
||||
func (s *FileStore) GetStoreEngine() StoreEngine {
|
||||
func (s *FileStore) GetStoreEngine() Engine {
|
||||
return FileStoreEngine
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -26,10 +25,14 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -50,7 +53,7 @@ type SqlStore struct {
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
storeEngine StoreEngine
|
||||
storeEngine Engine
|
||||
}
|
||||
|
||||
type installation struct {
|
||||
@@ -61,7 +64,7 @@ type installation struct {
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
// NewSqlStore creates a new SqlStore instance.
|
||||
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
sql, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -87,10 +90,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
|
||||
return nil, fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
|
||||
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &nbgroup.Group{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networks.Network{}, &networks.NetworkRouter{}, &networks.NetworkResource{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
||||
@@ -153,7 +156,7 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
|
||||
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
@@ -203,7 +206,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
|
||||
}
|
||||
|
||||
// generateAccountSQLTypes generates the GORM compatible types for the account
|
||||
func generateAccountSQLTypes(account *Account) {
|
||||
func generateAccountSQLTypes(account *types.Account) {
|
||||
for _, key := range account.SetupKeys {
|
||||
account.SetupKeysG = append(account.SetupKeysG, *key)
|
||||
}
|
||||
@@ -240,7 +243,7 @@ func generateAccountSQLTypes(account *Account) {
|
||||
|
||||
// checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank
|
||||
func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) {
|
||||
var acc Account
|
||||
var acc types.Account
|
||||
var domain string
|
||||
result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain)
|
||||
if result.Error != nil {
|
||||
@@ -254,7 +257,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error {
|
||||
func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error {
|
||||
start := time.Now()
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
@@ -335,14 +338,14 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
accountCopy := Account{
|
||||
accountCopy := types.Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
IsDomainPrimaryAccount: isPrimaryDomain,
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||
result := s.db.Model(&Account{}).
|
||||
result := s.db.Model(&types.Account{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
@@ -404,8 +407,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
usersToSave := make([]User, 0, len(users))
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
|
||||
usersToSave := make([]types.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
for id, pat := range user.PATs {
|
||||
@@ -425,7 +428,7 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
}
|
||||
|
||||
// SaveUser saves the given user to the database.
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
@@ -456,7 +459,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
|
||||
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -468,9 +471,9 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id").
|
||||
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
strings.ToLower(domain), true, PrivateCategory,
|
||||
strings.ToLower(domain), true, types.PrivateCategory,
|
||||
).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -483,8 +486,8 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||
var key SetupKey
|
||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) {
|
||||
var key types.SetupKey
|
||||
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -502,7 +505,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
||||
var token PersonalAccessToken
|
||||
var token types.PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -515,8 +518,8 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
|
||||
var token PersonalAccessToken
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
|
||||
var token types.PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -530,13 +533,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
var user User
|
||||
var user types.User
|
||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
}
|
||||
@@ -544,8 +547,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
var user User
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
@@ -558,8 +561,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||
var users []*types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -586,8 +589,8 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
|
||||
var accounts []Account
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
||||
var accounts []types.Account
|
||||
result := s.db.Find(&accounts)
|
||||
if result.Error != nil {
|
||||
return all
|
||||
@@ -602,7 +605,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
|
||||
return all
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
@@ -611,7 +614,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
}
|
||||
}()
|
||||
|
||||
var account Account
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
Preload(clause.Associations).
|
||||
@@ -626,15 +629,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
|
||||
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
||||
for i, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
var rules []*types.PolicyRule
|
||||
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "rule not found")
|
||||
}
|
||||
account.Policies[i].Rules = rules
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for _, key := range account.SetupKeysG {
|
||||
account.SetupKeys[key.Key] = key.Copy()
|
||||
}
|
||||
@@ -646,9 +649,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
}
|
||||
account.PeersG = nil
|
||||
|
||||
account.Users = make(map[string]*User, len(account.UsersG))
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for _, user := range account.UsersG {
|
||||
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
}
|
||||
@@ -677,8 +680,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
|
||||
var user User
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
|
||||
var user types.User
|
||||
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -694,7 +697,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
|
||||
return s.GetAccount(ctx, user.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
|
||||
if result.Error != nil {
|
||||
@@ -711,7 +714,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
|
||||
return s.GetAccount(ctx, peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey)
|
||||
if result.Error != nil {
|
||||
@@ -744,7 +747,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
@@ -757,7 +760,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.NewSetupKeyNotFoundError(setupKey)
|
||||
@@ -817,9 +820,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
||||
var accountNetwork AccountNetwork
|
||||
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
@@ -841,9 +844,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
||||
return &peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
|
||||
var accountSettings types.AccountSettings
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
@@ -854,7 +857,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
var user types.User
|
||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -892,7 +895,7 @@ func (s *SqlStore) Close(_ context.Context) error {
|
||||
}
|
||||
|
||||
// GetStoreEngine returns underlying store engine
|
||||
func (s *SqlStore) GetStoreEngine() StoreEngine {
|
||||
func (s *SqlStore) GetStoreEngine() Engine {
|
||||
return s.storeEngine
|
||||
}
|
||||
|
||||
@@ -984,8 +987,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
||||
var setupKey SetupKey
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
||||
var setupKey types.SetupKey
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, keyQueryCondition, key)
|
||||
if result.Error != nil {
|
||||
@@ -999,7 +1002,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
result := s.db.Model(&SetupKey{}).
|
||||
result := s.db.Model(&types.SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
"used_times": gorm.Expr("used_times + 1"),
|
||||
@@ -1116,7 +1119,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
@@ -1158,9 +1161,9 @@ func (s *SqlStore) GetDB() *gorm.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
|
||||
var accountDNSSettings AccountDNSSettings
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
var accountDNSSettings types.AccountDNSSettings
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -1175,7 +1178,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
var accountID string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
Select("id").First(&accountID, idQueryCondition, id)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -1189,8 +1192,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
|
||||
var account Account
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
|
||||
var account types.Account
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -1297,8 +1300,8 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
var policies []*Policy
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
|
||||
var policies []*types.Policy
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -1310,8 +1313,8 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
|
||||
var policy *Policy
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
|
||||
var policy *types.Policy
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
First(&policy, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -1325,7 +1328,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
|
||||
@@ -1336,7 +1339,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error {
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -1348,7 +1351,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||
Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||
@@ -1444,8 +1447,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
var setupKeys []*SetupKey
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
|
||||
var setupKeys []*types.SetupKey
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&setupKeys, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -1457,8 +1460,8 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
||||
var setupKey *SetupKey
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
|
||||
var setupKey *types.SetupKey
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||
if err := result.Error; err != nil {
|
||||
@@ -1473,7 +1476,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
// SaveSetupKey saves a setup key to the database.
|
||||
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
||||
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
|
||||
@@ -1485,7 +1488,7 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
|
||||
|
||||
// DeleteSetupKey deletes a setup key from the database.
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete setup key from store")
|
||||
@@ -1585,9 +1588,9 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save dns settings to store")
|
||||
@@ -1600,8 +1603,8 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networks.Network, error) {
|
||||
var networks []*networks.Network
|
||||
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
|
||||
var networks []*networkTypes.Network
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error)
|
||||
@@ -1611,8 +1614,8 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
|
||||
return networks, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.Network, error) {
|
||||
var network *networks.Network
|
||||
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
|
||||
var network *networkTypes.Network
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&network, accountAndIDQueryCondition, accountID, networkID)
|
||||
if result.Error != nil {
|
||||
@@ -1627,7 +1630,7 @@ func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStren
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networks.Network) error {
|
||||
func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error)
|
||||
@@ -1639,7 +1642,7 @@ func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength
|
||||
|
||||
func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&networks.Network{}, accountAndIDQueryCondition, accountID, networkID)
|
||||
Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete network from store")
|
||||
@@ -1652,8 +1655,8 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*networks.NetworkRouter, error) {
|
||||
var netRouters []*networks.NetworkRouter
|
||||
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
var netRouters []*routerTypes.NetworkRouter
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
|
||||
if result.Error != nil {
|
||||
@@ -1664,8 +1667,8 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
|
||||
return netRouters, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*networks.NetworkRouter, error) {
|
||||
var netRouter *networks.NetworkRouter
|
||||
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
|
||||
var netRouter *routerTypes.NetworkRouter
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
|
||||
if result.Error != nil {
|
||||
@@ -1679,7 +1682,7 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
|
||||
return netRouter, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *networks.NetworkRouter) error {
|
||||
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
|
||||
@@ -1691,7 +1694,7 @@ func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingSt
|
||||
|
||||
func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&networks.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
|
||||
Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete network router from store")
|
||||
@@ -1704,8 +1707,8 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*networks.NetworkResource, error) {
|
||||
var netResources []*networks.NetworkResource
|
||||
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
var netResources []*resourceTypes.NetworkResource
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
|
||||
if result.Error != nil {
|
||||
@@ -1716,8 +1719,8 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
|
||||
return netResources, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*networks.NetworkResource, error) {
|
||||
var netResources *networks.NetworkResource
|
||||
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
|
||||
var netResources *resourceTypes.NetworkResource
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
|
||||
if result.Error != nil {
|
||||
@@ -1731,7 +1734,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
|
||||
return netResources, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *networks.NetworkResource) error {
|
||||
func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error)
|
||||
@@ -1743,7 +1746,7 @@ func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength Locking
|
||||
|
||||
func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&networks.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
|
||||
Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete network resource from store")
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,18 +14,25 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestSqlite_NewStore(t *testing.T) {
|
||||
@@ -74,7 +81,7 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
const numPerAccount = 6000
|
||||
for n := 0; n < numPerAccount; n++ {
|
||||
@@ -87,14 +94,14 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
IP: netIP,
|
||||
Name: peerID,
|
||||
DNSLabel: peerID,
|
||||
UserID: userID,
|
||||
UserID: "testuser",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
account.Peers[peerID] = peer
|
||||
group, _ := account.GetGroupAll()
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
user := &User{
|
||||
user := &types.User{
|
||||
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
|
||||
AccountID: account.Id,
|
||||
}
|
||||
@@ -135,7 +142,7 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
}
|
||||
account.NameServerGroups[nameserver.ID] = nameserver
|
||||
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
}
|
||||
|
||||
@@ -217,7 +224,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
@@ -231,7 +238,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
setupKey, _ = types.GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
@@ -290,14 +297,14 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
user := types.NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*types.PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
}}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
@@ -307,7 +314,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
account.Networks = []*networks.Network{
|
||||
account.Networks = []*networkTypes.Network{
|
||||
{
|
||||
ID: "network_id",
|
||||
AccountID: account.Id,
|
||||
@@ -315,7 +322,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
Description: "network description",
|
||||
},
|
||||
}
|
||||
account.NetworkRouters = []*networks.NetworkRouter{
|
||||
account.NetworkRouters = []*routerTypes.NetworkRouter{
|
||||
{
|
||||
ID: "router_id",
|
||||
NetworkID: account.Networks[0].ID,
|
||||
@@ -325,7 +332,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
Metric: 1,
|
||||
},
|
||||
}
|
||||
account.NetworkResources = []*networks.NetworkResource{
|
||||
account.NetworkResources = []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: "resource_id",
|
||||
NetworkID: account.Networks[0].ID,
|
||||
@@ -367,16 +374,16 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
var rules []*types.PolicyRule
|
||||
err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
|
||||
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
|
||||
|
||||
}
|
||||
|
||||
for _, accountUser := range account.Users {
|
||||
var pats []*PersonalAccessToken
|
||||
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
var pats []*types.PersonalAccessToken
|
||||
err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
|
||||
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
|
||||
|
||||
@@ -399,7 +406,7 @@ func TestSqlite_GetAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -422,7 +429,7 @@ func TestSqlite_SavePeer(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -472,7 +479,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -527,7 +534,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -581,7 +588,7 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -604,7 +611,7 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -628,7 +635,7 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -664,7 +671,7 @@ func TestMigrate(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to parse CIDR")
|
||||
|
||||
type network struct {
|
||||
Network
|
||||
types.Network
|
||||
Net net.IPNet `gorm:"serializer:gob"`
|
||||
}
|
||||
|
||||
@@ -679,7 +686,7 @@ func TestMigrate(t *testing.T) {
|
||||
}
|
||||
|
||||
type account struct {
|
||||
Account
|
||||
types.Account
|
||||
Network *network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||
}
|
||||
@@ -739,23 +746,10 @@ func TestMigrate(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func newSqliteStore(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
store, err := NewSqliteStore(context.Background(), t.TempDir(), nil)
|
||||
t.Cleanup(func() {
|
||||
store.Close(context.Background())
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func newAccount(store Store, id int) error {
|
||||
str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
|
||||
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["p"+str] = &nbpeer.Peer{
|
||||
Key: "peerkey" + str,
|
||||
@@ -794,7 +788,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
@@ -808,7 +802,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
setupKey, _ = types.GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
@@ -867,14 +861,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
user := types.NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*types.PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
}}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
@@ -915,16 +909,16 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
var rules []*types.PolicyRule
|
||||
err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
|
||||
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
|
||||
|
||||
}
|
||||
|
||||
for _, accountUser := range account.Users {
|
||||
var pats []*PersonalAccessToken
|
||||
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
var pats []*types.PersonalAccessToken
|
||||
err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
|
||||
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
|
||||
|
||||
@@ -938,7 +932,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -979,7 +973,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -999,7 +993,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1017,7 +1011,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1030,7 +1024,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
defer cleanup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1075,7 +1069,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
|
||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -1117,7 +1111,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
|
||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1140,7 +1134,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
|
||||
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1158,14 +1152,14 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
|
||||
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, encodedHashedKey, setupKey.Key)
|
||||
assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret)
|
||||
assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret)
|
||||
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
|
||||
assert.Equal(t, "Default key", setupKey.Name)
|
||||
}
|
||||
|
||||
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1201,7 +1195,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
|
||||
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1232,7 +1226,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1246,7 +1240,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1292,7 +1286,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1306,7 +1300,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
|
||||
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1322,7 +1316,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
|
||||
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1334,7 +1328,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1377,7 +1371,7 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1398,7 +1392,7 @@ func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1423,7 +1417,7 @@ func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1471,7 +1465,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1518,7 +1512,7 @@ func TestSqlStore_DeleteGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeerByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1564,7 +1558,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1606,7 +1600,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1652,7 +1646,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1695,7 +1689,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePostureChecks(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1736,7 +1730,7 @@ func TestSqlStore_SavePostureChecks(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePostureChecks(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1783,7 +1777,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPolicyByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1829,23 +1823,23 @@ func TestSqlStore_GetPolicyByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_CreatePolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
ID: "policy-id",
|
||||
AccountID: accountID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1859,7 +1853,7 @@ func TestSqlStore_CreatePolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1882,7 +1876,7 @@ func TestSqlStore_SavePolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1898,7 +1892,7 @@ func TestSqlStore_DeletePolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetDNSSettings(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1942,7 +1936,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveDNSSettings(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1961,7 +1955,7 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1998,7 +1992,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNameServerByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2044,7 +2038,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNameServerGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2076,7 +2070,7 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2091,8 +2085,97 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
require.Nil(t, nsGroup)
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[nbroute.ID]*nbroute.Route)
|
||||
setupKeys := map[string]*types.SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
|
||||
owner := types.NewOwnerUser(userID)
|
||||
owner.AccountID = accountID
|
||||
users[userID] = owner
|
||||
|
||||
dnsSettings := types.DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
}
|
||||
log.WithContext(ctx).Debugf("created new account %s", accountID)
|
||||
|
||||
acc := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userID,
|
||||
Domain: domain,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
Settings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
},
|
||||
}
|
||||
|
||||
if err := addAllGroup(acc); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *types.Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &nbgroup.Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||
}
|
||||
account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &types.Policy{
|
||||
ID: id,
|
||||
Name: types.DefaultRuleName,
|
||||
Description: types.DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: id,
|
||||
Name: types.DefaultRuleName,
|
||||
Description: types.DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Sources: []string{allGroup.ID},
|
||||
Destinations: []string{allGroup.ID},
|
||||
Bidirectional: true,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.Policies = []*types.Policy{defaultPolicy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountNetworks(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2124,7 +2207,7 @@ func TestSqlStore_GetAccountNetworks(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNetworkByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2170,12 +2253,12 @@ func TestSqlStore_GetNetworkByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNetwork(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
network := &networks.Network{
|
||||
network := &networkTypes.Network{
|
||||
ID: "net-id",
|
||||
AccountID: accountID,
|
||||
Name: "net",
|
||||
@@ -2190,7 +2273,7 @@ func TestSqlStore_SaveNetwork(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNetwork(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2209,7 +2292,7 @@ func TestSqlStore_DeleteNetwork(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2242,7 +2325,7 @@ func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2288,14 +2371,14 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
networkID := "ct286bi7qv930dsrrug0"
|
||||
|
||||
netRouter, err := networks.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0)
|
||||
netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter)
|
||||
@@ -2307,7 +2390,7 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2326,7 +2409,7 @@ func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2359,7 +2442,7 @@ func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNetworkResourceByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2405,14 +2488,14 @@ func TestSqlStore_GetNetworkResourceByID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNetworkResource(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
networkID := "ct286bi7qv930dsrrug0"
|
||||
|
||||
netResource, err := networks.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com")
|
||||
netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource)
|
||||
@@ -2424,7 +2507,7 @@ func TestSqlStore_SaveNetworkResource(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNetworkResource(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
@@ -26,6 +26,9 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
@@ -42,31 +45,31 @@ const (
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
||||
SaveAccount(ctx context.Context, account *types.Account) error
|
||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
||||
SaveUsers(accountID string, users map[string]*types.User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
@@ -81,10 +84,10 @@ type Store interface {
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
|
||||
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
|
||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
@@ -106,11 +109,11 @@ type Store interface {
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error
|
||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
@@ -123,7 +126,7 @@ type Store interface {
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
@@ -137,45 +140,45 @@ type Store interface {
|
||||
|
||||
// Close should close the store persisting all unsaved data.
|
||||
Close(ctx context.Context) error
|
||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||
// GetStoreEngine should return Engine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
GetStoreEngine() Engine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networks.Network, error)
|
||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.Network, error)
|
||||
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networks.Network) error
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
||||
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error
|
||||
DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error
|
||||
|
||||
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*networks.NetworkRouter, error)
|
||||
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*networks.NetworkRouter, error)
|
||||
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *networks.NetworkRouter) error
|
||||
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
|
||||
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
|
||||
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error
|
||||
DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error
|
||||
|
||||
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*networks.NetworkResource, error)
|
||||
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*networks.NetworkResource, error)
|
||||
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *networks.NetworkResource) error
|
||||
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
|
||||
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
|
||||
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
||||
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
type Engine string
|
||||
|
||||
const (
|
||||
FileStoreEngine StoreEngine = "jsonfile"
|
||||
SqliteStoreEngine StoreEngine = "sqlite"
|
||||
PostgresStoreEngine StoreEngine = "postgres"
|
||||
FileStoreEngine Engine = "jsonfile"
|
||||
SqliteStoreEngine Engine = "sqlite"
|
||||
PostgresStoreEngine Engine = "postgres"
|
||||
|
||||
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
)
|
||||
|
||||
func getStoreEngineFromEnv() StoreEngine {
|
||||
func getStoreEngineFromEnv() Engine {
|
||||
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
|
||||
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
value := StoreEngine(strings.ToLower(kind))
|
||||
value := Engine(strings.ToLower(kind))
|
||||
if value == SqliteStoreEngine || value == PostgresStoreEngine {
|
||||
return value
|
||||
}
|
||||
@@ -187,7 +190,7 @@ func getStoreEngineFromEnv() StoreEngine {
|
||||
// If no engine is specified, it attempts to retrieve it from the environment.
|
||||
// If still not specified, it defaults to using SQLite.
|
||||
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
|
||||
func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine {
|
||||
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
|
||||
if kind == "" {
|
||||
kind = getStoreEngineFromEnv()
|
||||
if kind == "" {
|
||||
@@ -213,7 +216,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) Store
|
||||
}
|
||||
|
||||
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
|
||||
func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||
kind = getStoreEngine(ctx, dataDir, kind)
|
||||
|
||||
if err := checkFileStoreEngine(kind, dataDir); err != nil {
|
||||
@@ -232,7 +235,7 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileStoreEngine(kind StoreEngine, dataDir string) error {
|
||||
func checkFileStoreEngine(kind Engine, dataDir string) error {
|
||||
if kind == FileStoreEngine {
|
||||
storeFile := filepath.Join(dataDir, storeFileName)
|
||||
if util.FileExists(storeFile) {
|
||||
@@ -259,7 +262,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
|
||||
func getMigrations(ctx context.Context) []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net")
|
||||
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network")
|
||||
@@ -274,7 +277,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db)
|
||||
return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -76,11 +76,3 @@ func BenchmarkTest_StoreRead(b *testing.B) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newStore(t *testing.T) Store {
|
||||
t.Helper()
|
||||
|
||||
store := newSqliteStore(t)
|
||||
|
||||
return store
|
||||
}
|
||||
1182
management/server/types/account.go
Normal file
1182
management/server/types/account.go
Normal file
File diff suppressed because it is too large
Load Diff
16
management/server/types/dns_settings.go
Normal file
16
management/server/types/dns_settings.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package types
|
||||
|
||||
// DNSSettings defines dns settings at the account level
|
||||
type DNSSettings struct {
|
||||
// DisabledManagementGroups groups whose DNS management is disabled
|
||||
DisabledManagementGroups []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of the DNS settings
|
||||
func (d DNSSettings) Copy() DNSSettings {
|
||||
settings := DNSSettings{
|
||||
DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)),
|
||||
}
|
||||
copy(settings.DisabledManagementGroups, d.DisabledManagementGroups)
|
||||
return settings
|
||||
}
|
||||
129
management/server/types/firewall_rule.go
Normal file
129
management/server/types/firewall_rule.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
FirewallRuleDirectionIN = 0
|
||||
FirewallRuleDirectionOUT = 1
|
||||
)
|
||||
|
||||
// FirewallRule is a rule of the firewall.
|
||||
type FirewallRule struct {
|
||||
// PeerIP of the peer
|
||||
PeerIP string
|
||||
|
||||
// Direction of the traffic
|
||||
Direction int
|
||||
|
||||
// Action of the traffic
|
||||
Action string
|
||||
|
||||
// Protocol of the traffic
|
||||
Protocol string
|
||||
|
||||
// Port of the traffic
|
||||
Port string
|
||||
}
|
||||
|
||||
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
|
||||
rulesExists := make(map[string]struct{})
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
sourceRanges := make([]string, 0, len(groupPeers))
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
|
||||
}
|
||||
|
||||
baseRule := RouteFirewallRule{
|
||||
SourceRanges: sourceRanges,
|
||||
Action: string(rule.Action),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(rule.Protocol),
|
||||
IsDynamic: route.IsDynamic(),
|
||||
}
|
||||
|
||||
// generate rule for port range
|
||||
if len(rule.Ports) == 0 {
|
||||
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
|
||||
} else {
|
||||
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
|
||||
|
||||
}
|
||||
|
||||
// TODO: generate IPv6 rules for dynamic routes
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
|
||||
func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
ruleIDBase := generateRuleIDBase(rule, baseRule)
|
||||
if len(rule.Ports) == 0 {
|
||||
if len(rule.PortRanges) == 0 {
|
||||
if _, ok := rulesExists[ruleIDBase]; !ok {
|
||||
rulesExists[ruleIDBase] = struct{}{}
|
||||
rules = append(rules, &baseRule)
|
||||
}
|
||||
} else {
|
||||
for _, portRange := range rule.PortRanges {
|
||||
ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
|
||||
if _, ok := rulesExists[ruleID]; !ok {
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
pr := baseRule
|
||||
pr.PortRange = portRange
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// generateRulesWithPorts generates rules when specific ports are provided.
|
||||
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
ruleIDBase := generateRuleIDBase(rule, baseRule)
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
ruleID := ruleIDBase + port
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
pr := baseRule
|
||||
p, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
pr.Port = uint16(p)
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// generateRuleIDBase generates the base rule ID for checking duplicates.
|
||||
func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
|
||||
return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package types
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
@@ -43,7 +43,7 @@ type Network struct {
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
Mu sync.Mutex `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
@@ -66,15 +66,15 @@ func NewNetwork() *Network {
|
||||
|
||||
// IncSerial increments Serial by 1 reflecting that the network state has been changed
|
||||
func (n *Network) IncSerial() {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
n.Mu.Lock()
|
||||
defer n.Mu.Unlock()
|
||||
n.Serial++
|
||||
}
|
||||
|
||||
// CurrentSerial returns the Network.Serial of the network (latest state id)
|
||||
func (n *Network) CurrentSerial() uint64 {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
n.Mu.Lock()
|
||||
defer n.Mu.Unlock()
|
||||
return n.Serial
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package types
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
116
management/server/types/policy.go
Normal file
116
management/server/types/policy.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package types
|
||||
|
||||
const (
|
||||
// PolicyTrafficActionAccept indicates that the traffic is accepted
|
||||
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
|
||||
// PolicyTrafficActionDrop indicates that the traffic is dropped
|
||||
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
|
||||
)
|
||||
|
||||
const (
|
||||
// PolicyRuleProtocolALL type of traffic
|
||||
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
|
||||
// PolicyRuleProtocolTCP type of traffic
|
||||
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
|
||||
// PolicyRuleProtocolUDP type of traffic
|
||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||
// PolicyRuleProtocolICMP type of traffic
|
||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||
)
|
||||
|
||||
const (
|
||||
// PolicyRuleFlowDirect allows traffic from source to destination
|
||||
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
|
||||
// PolicyRuleFlowBidirect allows traffic to both directions
|
||||
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultRuleName is a name for the Default rule that is created for every account
|
||||
DefaultRuleName = "Default"
|
||||
// DefaultRuleDescription is a description for the Default rule that is created for every account
|
||||
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
|
||||
// DefaultPolicyName is a name for the Default policy that is created for every account
|
||||
DefaultPolicyName = "Default"
|
||||
// DefaultPolicyDescription is a description for the Default policy that is created for every account
|
||||
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
|
||||
)
|
||||
|
||||
// PolicyUpdateOperation operation object with type and values to be applied
|
||||
type PolicyUpdateOperation struct {
|
||||
Type PolicyUpdateOperationType
|
||||
Values []string
|
||||
}
|
||||
|
||||
// Policy of the Rego query
|
||||
type Policy struct {
|
||||
// ID of the policy'
|
||||
ID string `gorm:"primaryKey"`
|
||||
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// Name of the Policy
|
||||
Name string
|
||||
|
||||
// Description of the policy visible in the UI
|
||||
Description string
|
||||
|
||||
// Enabled status of the policy
|
||||
Enabled bool
|
||||
|
||||
// Rules of the policy
|
||||
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// SourcePostureChecks are ID references to Posture checks for policy source groups
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
Rules: make([]*PolicyRule, len(p.Rules)),
|
||||
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
|
||||
}
|
||||
for i, r := range p.Rules {
|
||||
c.Rules[i] = r.Copy()
|
||||
}
|
||||
copy(c.SourcePostureChecks, p.SourcePostureChecks)
|
||||
return c
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to this policy
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{"name": p.Name}
|
||||
}
|
||||
|
||||
// UpgradeAndFix different version of policies to latest version
|
||||
func (p *Policy) UpgradeAndFix() {
|
||||
for _, r := range p.Rules {
|
||||
// start migrate from version v0.20.3
|
||||
if r.Protocol == "" {
|
||||
r.Protocol = PolicyRuleProtocolALL
|
||||
}
|
||||
if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional {
|
||||
r.Bidirectional = true
|
||||
}
|
||||
// -- v0.20.4
|
||||
}
|
||||
}
|
||||
|
||||
// RuleGroups returns a list of all groups referenced in the policy's rules,
|
||||
// including sources and destinations.
|
||||
func (p *Policy) RuleGroups() []string {
|
||||
groups := make([]string, 0)
|
||||
for _, rule := range p.Rules {
|
||||
groups = append(groups, rule.Sources...)
|
||||
groups = append(groups, rule.Destinations...)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
81
management/server/types/policyrule.go
Normal file
81
management/server/types/policyrule.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package types
|
||||
|
||||
// PolicyUpdateOperationType operation type
|
||||
type PolicyUpdateOperationType int
|
||||
|
||||
// PolicyTrafficActionType action type for the firewall
|
||||
type PolicyTrafficActionType string
|
||||
|
||||
// PolicyRuleProtocolType type of traffic
|
||||
type PolicyRuleProtocolType string
|
||||
|
||||
// PolicyRuleDirection direction of traffic
|
||||
type PolicyRuleDirection string
|
||||
|
||||
// RulePortRange represents a range of ports for a firewall rule.
|
||||
type RulePortRange struct {
|
||||
Start uint16
|
||||
End uint16
|
||||
}
|
||||
|
||||
// PolicyRule is the metadata of the policy
|
||||
type PolicyRule struct {
|
||||
// ID of the policy rule
|
||||
ID string `gorm:"primaryKey"`
|
||||
|
||||
// PolicyID is a reference to Policy that this object belongs
|
||||
PolicyID string `json:"-" gorm:"index"`
|
||||
|
||||
// Name of the rule visible in the UI
|
||||
Name string
|
||||
|
||||
// Description of the rule visible in the UI
|
||||
Description string
|
||||
|
||||
// Enabled status of rule in the system
|
||||
Enabled bool
|
||||
|
||||
// Action policy accept or drops packets
|
||||
Action PolicyTrafficActionType
|
||||
|
||||
// Destinations policy destination groups
|
||||
Destinations []string `gorm:"serializer:json"`
|
||||
|
||||
// Sources policy source groups
|
||||
Sources []string `gorm:"serializer:json"`
|
||||
|
||||
// Bidirectional define if the rule is applicable in both directions, sources, and destinations
|
||||
Bidirectional bool
|
||||
|
||||
// Protocol type of the traffic
|
||||
Protocol PolicyRuleProtocolType
|
||||
|
||||
// Ports or it ranges list
|
||||
Ports []string `gorm:"serializer:json"`
|
||||
|
||||
// PortRanges a list of port ranges.
|
||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
rule := &PolicyRule{
|
||||
ID: pm.ID,
|
||||
PolicyID: pm.PolicyID,
|
||||
Name: pm.Name,
|
||||
Description: pm.Description,
|
||||
Enabled: pm.Enabled,
|
||||
Action: pm.Action,
|
||||
Destinations: make([]string, len(pm.Destinations)),
|
||||
Sources: make([]string, len(pm.Sources)),
|
||||
Bidirectional: pm.Bidirectional,
|
||||
Protocol: pm.Protocol,
|
||||
Ports: make([]string, len(pm.Ports)),
|
||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
copy(rule.Ports, pm.Ports)
|
||||
copy(rule.PortRanges, pm.PortRanges)
|
||||
return rule
|
||||
}
|
||||
25
management/server/types/route_firewall_rule.go
Normal file
25
management/server/types/route_firewall_rule.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package types
|
||||
|
||||
// RouteFirewallRule a firewall rule applicable for a routed network.
|
||||
type RouteFirewallRule struct {
|
||||
// SourceRanges IP ranges of the routing peers.
|
||||
SourceRanges []string
|
||||
|
||||
// Action of the traffic when the rule is applicable
|
||||
Action string
|
||||
|
||||
// Destination a network prefix for the routed traffic
|
||||
Destination string
|
||||
|
||||
// Protocol of the traffic
|
||||
Protocol string
|
||||
|
||||
// Port of the traffic
|
||||
Port uint16
|
||||
|
||||
// PortRange represents the range of ports for a firewall rule
|
||||
PortRange RulePortRange
|
||||
|
||||
// isDynamic indicates whether the rule is for DNS routing
|
||||
IsDynamic bool
|
||||
}
|
||||
63
management/server/types/settings.go
Normal file
63
management/server/types/settings.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
)
|
||||
|
||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||
type Settings struct {
|
||||
// PeerLoginExpirationEnabled globally enables or disables peer login expiration
|
||||
PeerLoginExpirationEnabled bool
|
||||
|
||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
|
||||
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
|
||||
PeerInactivityExpirationEnabled bool
|
||||
|
||||
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
|
||||
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
|
||||
PeerInactivityExpiration time.Duration
|
||||
|
||||
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
|
||||
RegularUsersViewBlocked bool
|
||||
|
||||
// GroupsPropagationEnabled allows to propagate auto groups from the user to the peer
|
||||
GroupsPropagationEnabled bool
|
||||
|
||||
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||
// and add it to account groups.
|
||||
JWTGroupsEnabled bool
|
||||
|
||||
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||
JWTGroupsClaimName string
|
||||
|
||||
// JWTAllowGroups list of groups to which users are allowed access
|
||||
JWTAllowGroups []string `gorm:"serializer:json"`
|
||||
|
||||
// Extra is a dictionary of Account settings
|
||||
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
func (s *Settings) Copy() *Settings {
|
||||
settings := &Settings{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||
JWTAllowGroups: s.JWTAllowGroups,
|
||||
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: s.PeerInactivityExpiration,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
}
|
||||
return settings
|
||||
}
|
||||
181
management/server/types/setupkey.go
Normal file
181
management/server/types/setupkey.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// SetupKeyReusable is a multi-use key (can be used for multiple machines)
|
||||
SetupKeyReusable SetupKeyType = "reusable"
|
||||
// SetupKeyOneOff is a single use key (can be used only once)
|
||||
SetupKeyOneOff SetupKeyType = "one-off"
|
||||
// DefaultSetupKeyDuration = 1 month
|
||||
DefaultSetupKeyDuration = 24 * 30 * time.Hour
|
||||
// DefaultSetupKeyName is a default name of the default setup key
|
||||
DefaultSetupKeyName = "Default key"
|
||||
// SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key
|
||||
SetupKeyUnlimitedUsage = 0
|
||||
)
|
||||
|
||||
// SetupKeyType is the type of setup key
|
||||
type SetupKeyType string
|
||||
|
||||
// SetupKey represents a pre-authorized key used to register machines (peers)
|
||||
type SetupKey struct {
|
||||
Id string
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Key string
|
||||
KeySecret string
|
||||
Name string
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
|
||||
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
|
||||
Revoked bool
|
||||
// UsedTimes indicates how many times the key was used
|
||||
UsedTimes int
|
||||
// LastUsed last time the key was used for peer registration
|
||||
LastUsed time.Time
|
||||
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
// UsageLimit indicates the number of times this key can be used to enroll a machine.
|
||||
// The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int
|
||||
// Ephemeral indicate if the peers will be ephemeral or not
|
||||
Ephemeral bool
|
||||
}
|
||||
|
||||
// Copy copies SetupKey to a new object
|
||||
func (key *SetupKey) Copy() *SetupKey {
|
||||
autoGroups := make([]string, len(key.AutoGroups))
|
||||
copy(autoGroups, key.AutoGroups)
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = key.CreatedAt
|
||||
}
|
||||
return &SetupKey{
|
||||
Id: key.Id,
|
||||
AccountID: key.AccountID,
|
||||
Key: key.Key,
|
||||
KeySecret: key.KeySecret,
|
||||
Name: key.Name,
|
||||
Type: key.Type,
|
||||
CreatedAt: key.CreatedAt,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
UpdatedAt: key.UpdatedAt,
|
||||
Revoked: key.Revoked,
|
||||
UsedTimes: key.UsedTimes,
|
||||
LastUsed: key.LastUsed,
|
||||
AutoGroups: autoGroups,
|
||||
UsageLimit: key.UsageLimit,
|
||||
Ephemeral: key.Ephemeral,
|
||||
}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the setup key
|
||||
func (key *SetupKey) EventMeta() map[string]any {
|
||||
return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
|
||||
}
|
||||
|
||||
// HiddenKey returns the Key value hidden with "*" and a 5 character prefix.
|
||||
// E.g., "831F6*******************************"
|
||||
func HiddenKey(key string, length int) string {
|
||||
prefix := key[0:5]
|
||||
if length > utf8.RuneCountInString(key) {
|
||||
length = utf8.RuneCountInString(key) - len(prefix)
|
||||
}
|
||||
return prefix + strings.Repeat("*", length)
|
||||
}
|
||||
|
||||
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
|
||||
func (key *SetupKey) IncrementUsage() *SetupKey {
|
||||
c := key.Copy()
|
||||
c.UsedTimes++
|
||||
c.LastUsed = time.Now().UTC()
|
||||
return c
|
||||
}
|
||||
|
||||
// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to
|
||||
func (key *SetupKey) IsValid() bool {
|
||||
return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed()
|
||||
}
|
||||
|
||||
// IsRevoked if key was revoked
|
||||
func (key *SetupKey) IsRevoked() bool {
|
||||
return key.Revoked
|
||||
}
|
||||
|
||||
// IsExpired if key was expired
|
||||
func (key *SetupKey) IsExpired() bool {
|
||||
if key.ExpiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(key.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage.
|
||||
func (key *SetupKey) IsOverUsed() bool {
|
||||
limit := key.UsageLimit
|
||||
if key.Type == SetupKeyOneOff {
|
||||
limit = 1
|
||||
}
|
||||
return limit > 0 && key.UsedTimes >= limit
|
||||
}
|
||||
|
||||
// GenerateSetupKey generates a new setup key
|
||||
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
|
||||
usageLimit int, ephemeral bool) (*SetupKey, string) {
|
||||
key := strings.ToUpper(uuid.New().String())
|
||||
limit := usageLimit
|
||||
if t == SetupKeyOneOff {
|
||||
limit = 1
|
||||
}
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if validFor != 0 {
|
||||
expiresAt = time.Now().UTC().Add(validFor)
|
||||
}
|
||||
|
||||
hashedKey := sha256.Sum256([]byte(key))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
|
||||
return &SetupKey{
|
||||
Id: strconv.Itoa(int(Hash(key))),
|
||||
Key: encodedHashedKey,
|
||||
KeySecret: HiddenKey(key, 4),
|
||||
Name: name,
|
||||
Type: t,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: expiresAt,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
Revoked: false,
|
||||
UsedTimes: 0,
|
||||
AutoGroups: autoGroups,
|
||||
UsageLimit: limit,
|
||||
Ephemeral: ephemeral,
|
||||
}, key
|
||||
}
|
||||
|
||||
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
|
||||
func GenerateDefaultSetupKey() (*SetupKey, string) {
|
||||
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
|
||||
SetupKeyUnlimitedUsage, false)
|
||||
}
|
||||
|
||||
func Hash(s string) uint32 {
|
||||
h := fnv.New32a()
|
||||
_, err := h.Write([]byte(s))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return h.Sum32()
|
||||
}
|
||||
231
management/server/types/user.go
Normal file
231
management/server/types/user.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
)
|
||||
|
||||
const (
|
||||
UserRoleOwner UserRole = "owner"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
UserRoleBillingAdmin UserRole = "billing_admin"
|
||||
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
|
||||
UserIssuedAPI = "api"
|
||||
UserIssuedIntegration = "integration"
|
||||
)
|
||||
|
||||
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
|
||||
func StrRoleToUserRole(strRole string) UserRole {
|
||||
switch strings.ToLower(strRole) {
|
||||
case "owner":
|
||||
return UserRoleOwner
|
||||
case "admin":
|
||||
return UserRoleAdmin
|
||||
case "user":
|
||||
return UserRoleUser
|
||||
case "billing_admin":
|
||||
return UserRoleBillingAdmin
|
||||
default:
|
||||
return UserRoleUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// UserStatus is the status of a User
|
||||
type UserStatus string
|
||||
|
||||
// UserRole is the role of a User
|
||||
type UserRole string
|
||||
|
||||
type UserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
IsServiceUser bool `json:"is_service_user"`
|
||||
IsBlocked bool `json:"is_blocked"`
|
||||
NonDeletable bool `json:"non_deletable"`
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
Permissions UserPermissions `json:"permissions"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
|
||||
// User represents a user of the system
|
||||
type User struct {
|
||||
Id string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Role UserRole
|
||||
IsServiceUser bool
|
||||
// NonDeletable indicates whether the service user can be deleted
|
||||
NonDeletable bool
|
||||
// ServiceUserName is only set if IsServiceUser is true
|
||||
ServiceUserName string
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
LastLogin time.Time
|
||||
// CreatedAt records the time the user was created
|
||||
CreatedAt time.Time
|
||||
|
||||
// Issued of the user
|
||||
Issued string `gorm:"default:api"`
|
||||
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
func (u *User) IsBlocked() bool {
|
||||
return u.Blocked
|
||||
}
|
||||
|
||||
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
||||
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||
}
|
||||
|
||||
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
||||
func (u *User) HasAdminPower() bool {
|
||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||
}
|
||||
|
||||
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
|
||||
func (u *User) IsAdminOrServiceUser() bool {
|
||||
return u.HasAdminPower() || u.IsServiceUser
|
||||
}
|
||||
|
||||
// IsRegularUser checks if the user is a regular user.
|
||||
func (u *User) IsRegularUser() bool {
|
||||
return !u.HasAdminPower() && !u.IsServiceUser
|
||||
}
|
||||
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
dashboardViewPermissions := "full"
|
||||
if !u.HasAdminPower() {
|
||||
dashboardViewPermissions = "limited"
|
||||
if settings.RegularUsersViewBlocked {
|
||||
dashboardViewPermissions = "blocked"
|
||||
}
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.LastLogin,
|
||||
Issued: u.Issued,
|
||||
Permissions: UserPermissions{
|
||||
DashboardView: dashboardViewPermissions,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
|
||||
}
|
||||
|
||||
userStatus := UserStatusActive
|
||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.LastLogin,
|
||||
Issued: u.Issued,
|
||||
Permissions: UserPermissions{
|
||||
DashboardView: dashboardViewPermissions,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Copy the user
|
||||
func (u *User) Copy() *User {
|
||||
autoGroups := make([]string, len(u.AutoGroups))
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||
for k, v := range u.PATs {
|
||||
pats[k] = v.Copy()
|
||||
}
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
AccountID: u.AccountID,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
NonDeletable: u.NonDeletable,
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
LastLogin: u.LastLogin,
|
||||
CreatedAt: u.CreatedAt,
|
||||
Issued: u.Issued,
|
||||
IntegrationReference: u.IntegrationReference,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
IsServiceUser: isServiceUser,
|
||||
NonDeletable: nonDeletable,
|
||||
ServiceUserName: serviceUserName,
|
||||
AutoGroups: autoGroups,
|
||||
Issued: issued,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||
}
|
||||
|
||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||
func NewOwnerUser(id string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
||||
}
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *NetworkMap
|
||||
NetworkMap *types.NetworkMap
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
|
||||
@@ -15,215 +15,16 @@ import (
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
UserRoleOwner UserRole = "owner"
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
UserRoleBillingAdmin UserRole = "billing_admin"
|
||||
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
|
||||
UserIssuedAPI = "api"
|
||||
UserIssuedIntegration = "integration"
|
||||
)
|
||||
|
||||
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
|
||||
func StrRoleToUserRole(strRole string) UserRole {
|
||||
switch strings.ToLower(strRole) {
|
||||
case "owner":
|
||||
return UserRoleOwner
|
||||
case "admin":
|
||||
return UserRoleAdmin
|
||||
case "user":
|
||||
return UserRoleUser
|
||||
case "billing_admin":
|
||||
return UserRoleBillingAdmin
|
||||
default:
|
||||
return UserRoleUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// UserStatus is the status of a User
|
||||
type UserStatus string
|
||||
|
||||
// UserRole is the role of a User
|
||||
type UserRole string
|
||||
|
||||
// User represents a user of the system
|
||||
type User struct {
|
||||
Id string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Role UserRole
|
||||
IsServiceUser bool
|
||||
// NonDeletable indicates whether the service user can be deleted
|
||||
NonDeletable bool
|
||||
// ServiceUserName is only set if IsServiceUser is true
|
||||
ServiceUserName string
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
LastLogin time.Time
|
||||
// CreatedAt records the time the user was created
|
||||
CreatedAt time.Time
|
||||
|
||||
// Issued of the user
|
||||
Issued string `gorm:"default:api"`
|
||||
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
func (u *User) IsBlocked() bool {
|
||||
return u.Blocked
|
||||
}
|
||||
|
||||
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
||||
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||
}
|
||||
|
||||
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
|
||||
func (u *User) HasAdminPower() bool {
|
||||
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
|
||||
}
|
||||
|
||||
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
|
||||
func (u *User) IsAdminOrServiceUser() bool {
|
||||
return u.HasAdminPower() || u.IsServiceUser
|
||||
}
|
||||
|
||||
// IsRegularUser checks if the user is a regular user.
|
||||
func (u *User) IsRegularUser() bool {
|
||||
return !u.HasAdminPower() && !u.IsServiceUser
|
||||
}
|
||||
|
||||
// ToUserInfo converts a User object to a UserInfo object.
|
||||
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
|
||||
autoGroups := u.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
dashboardViewPermissions := "full"
|
||||
if !u.HasAdminPower() {
|
||||
dashboardViewPermissions = "limited"
|
||||
if settings.RegularUsersViewBlocked {
|
||||
dashboardViewPermissions = "blocked"
|
||||
}
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.LastLogin,
|
||||
Issued: u.Issued,
|
||||
Permissions: UserPermissions{
|
||||
DashboardView: dashboardViewPermissions,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
|
||||
}
|
||||
|
||||
userStatus := UserStatusActive
|
||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.LastLogin,
|
||||
Issued: u.Issued,
|
||||
Permissions: UserPermissions{
|
||||
DashboardView: dashboardViewPermissions,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Copy the user
|
||||
func (u *User) Copy() *User {
|
||||
autoGroups := make([]string, len(u.AutoGroups))
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||
for k, v := range u.PATs {
|
||||
pats[k] = v.Copy()
|
||||
}
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
AccountID: u.AccountID,
|
||||
Role: u.Role,
|
||||
AutoGroups: autoGroups,
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
NonDeletable: u.NonDeletable,
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
LastLogin: u.LastLogin,
|
||||
CreatedAt: u.CreatedAt,
|
||||
Issued: u.Issued,
|
||||
IntegrationReference: u.IntegrationReference,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
IsServiceUser: isServiceUser,
|
||||
NonDeletable: nonDeletable,
|
||||
ServiceUserName: serviceUserName,
|
||||
AutoGroups: autoGroups,
|
||||
Issued: issued,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||
}
|
||||
|
||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||
func NewOwnerUser(id string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
||||
}
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -240,12 +41,12 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users")
|
||||
}
|
||||
|
||||
if role == UserRoleOwner {
|
||||
if role == types.UserRoleOwner {
|
||||
return nil, status.Errorf(status.InvalidArgument, "can't create a service user with owner role")
|
||||
}
|
||||
|
||||
newUserID := uuid.New().String()
|
||||
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
|
||||
newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI)
|
||||
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
||||
account.Users[newUserID] = newUser
|
||||
|
||||
@@ -257,29 +58,29 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
||||
meta := map[string]any{"name": newUser.ServiceUserName}
|
||||
am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
|
||||
|
||||
return &UserInfo{
|
||||
return &types.UserInfo{
|
||||
ID: newUser.Id,
|
||||
Email: "",
|
||||
Name: newUser.ServiceUserName,
|
||||
Role: string(newUser.Role),
|
||||
AutoGroups: newUser.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
Status: string(types.UserStatusActive),
|
||||
IsServiceUser: true,
|
||||
LastLogin: time.Time{},
|
||||
Issued: UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user under the given account. Effectively this is a user invite.
|
||||
func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *types.UserInfo) (*types.UserInfo, error) {
|
||||
if user.IsServiceUser {
|
||||
return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups)
|
||||
return am.createServiceUser(ctx, accountID, userID, types.StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups)
|
||||
}
|
||||
return am.inviteNewUser(ctx, accountID, userID, user)
|
||||
}
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -291,14 +92,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return nil, fmt.Errorf("provided user update is nil")
|
||||
}
|
||||
|
||||
invitedRole := StrRoleToUserRole(invite.Role)
|
||||
invitedRole := types.StrRoleToUserRole(invite.Role)
|
||||
|
||||
switch {
|
||||
case invite.Name == "":
|
||||
return nil, status.Errorf(status.InvalidArgument, "name can't be empty")
|
||||
case invite.Email == "":
|
||||
return nil, status.Errorf(status.InvalidArgument, "email can't be empty")
|
||||
case invitedRole == UserRoleOwner:
|
||||
case invitedRole == types.UserRoleOwner:
|
||||
return nil, status.Errorf(status.InvalidArgument, "can't invite a user with owner role")
|
||||
default:
|
||||
}
|
||||
@@ -348,7 +149,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
newUser := &types.User{
|
||||
Id: idpUser.ID,
|
||||
Role: invitedRole,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
@@ -373,19 +174,19 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
|
||||
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -409,7 +210,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// ListUsers returns lists of all users under the account.
|
||||
// It doesn't populate user information such as email or name.
|
||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -418,7 +219,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*User, 0, len(account.Users))
|
||||
users := make([]*types.User, 0, len(account.Users))
|
||||
for _, item := range account.Users {
|
||||
users = append(users, item)
|
||||
}
|
||||
@@ -426,7 +227,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) {
|
||||
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *types.Account, initiatorUserID string, targetUser *types.User) {
|
||||
meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt}
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta)
|
||||
delete(account.Users, targetUser.Id)
|
||||
@@ -458,12 +259,12 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
return status.Errorf(status.NotFound, "target user not found")
|
||||
}
|
||||
|
||||
if targetUser.Role == UserRoleOwner {
|
||||
if targetUser.Role == types.UserRoleOwner {
|
||||
return status.Errorf(status.PermissionDenied, "unable to delete a user with owner role")
|
||||
}
|
||||
|
||||
// disable deleting integration user if the initiator is not admin service user
|
||||
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
|
||||
if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only integration service user can delete this user")
|
||||
}
|
||||
|
||||
@@ -480,7 +281,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) error {
|
||||
meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -500,7 +301,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
|
||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *types.Account) (bool, error) {
|
||||
peers, err := account.FindUserPeers(targetUserID)
|
||||
if err != nil {
|
||||
return false, status.Errorf(status.Internal, "failed to find user peers")
|
||||
@@ -560,7 +361,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -591,7 +392,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
|
||||
pat, err := types.CreateNewPAT(tokenName, expiresIn, executingUser.Id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||
}
|
||||
@@ -660,13 +461,13 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
}
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -685,13 +486,13 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -700,7 +501,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
|
||||
pats := make([]*types.PersonalAccessToken, 0, len(targetUser.PATsG))
|
||||
for _, pat := range targetUser.PATsG {
|
||||
pats = append(pats, pat.Copy())
|
||||
}
|
||||
@@ -709,13 +510,13 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
||||
}
|
||||
|
||||
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
|
||||
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) {
|
||||
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
|
||||
}
|
||||
|
||||
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) {
|
||||
if update == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
@@ -723,7 +524,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -738,7 +539,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
||||
// SaveOrAddUsers updates existing users or adds new users to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) {
|
||||
if len(updates) == 0 {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
@@ -757,7 +558,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
|
||||
}
|
||||
|
||||
updatedUsers := make([]*UserInfo, 0, len(updates))
|
||||
updatedUsers := make([]*types.UserInfo, 0, len(updates))
|
||||
var (
|
||||
expiredPeers []*nbpeer.Peer
|
||||
userIDs []string
|
||||
@@ -808,7 +609,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
peerGroupsAdded := make(map[string][]string)
|
||||
peerGroupsRemoved := make(map[string][]string)
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
@@ -851,7 +652,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, transferredOwnerRole bool) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
||||
@@ -880,11 +681,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
|
||||
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
if newUser.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
removedGroups := util.Difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := util.Difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
|
||||
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, removedEvents...)
|
||||
@@ -895,7 +696,7 @@ func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, in
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
|
||||
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
@@ -922,7 +723,7 @@ func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, ini
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
|
||||
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
@@ -952,10 +753,10 @@ func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context,
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
|
||||
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
|
||||
func handleOwnerRoleTransfer(account *types.Account, initiatorUser, update *types.User) bool {
|
||||
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
|
||||
newInitiatorUser := initiatorUser.Copy()
|
||||
newInitiatorUser.Role = UserRoleAdmin
|
||||
newInitiatorUser.Role = types.UserRoleAdmin
|
||||
account.Users[initiatorUser.Id] = newInitiatorUser
|
||||
return true
|
||||
}
|
||||
@@ -965,7 +766,7 @@ func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool
|
||||
// getUserInfo retrieves the UserInfo for a given User and Account.
|
||||
// If the AccountManager has a non-nil idpManager and the User is not a service user,
|
||||
// it will attempt to look up the UserData from the cache.
|
||||
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
|
||||
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *types.User, account *types.Account) (*types.UserInfo, error) {
|
||||
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(ctx, user.Id, account)
|
||||
if err != nil {
|
||||
@@ -977,23 +778,23 @@ func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, acc
|
||||
}
|
||||
|
||||
// validateUserUpdate validates the update operation for a user.
|
||||
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
|
||||
func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update *types.User) error {
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||
}
|
||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||
}
|
||||
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||
if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||
}
|
||||
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
|
||||
if oldUser.IsServiceUser && update.Role == types.UserRoleOwner {
|
||||
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
|
||||
}
|
||||
|
||||
@@ -1012,7 +813,7 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) {
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) {
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
||||
defer unlock()
|
||||
@@ -1039,7 +840,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
|
||||
|
||||
userObj := account.Users[userID]
|
||||
|
||||
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
|
||||
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == types.UserRoleOwner {
|
||||
account.Domain = lowerDomain
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
@@ -1052,7 +853,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
|
||||
|
||||
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
||||
// based on provided user role.
|
||||
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1068,7 +869,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
usersFromIntegration := make([]*idp.UserData, 0)
|
||||
for _, user := range account.Users {
|
||||
if user.Issued == UserIssuedIntegration {
|
||||
if user.Issued == types.UserIssuedIntegration {
|
||||
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
@@ -1092,7 +893,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
||||
}
|
||||
|
||||
userInfos := make([]*UserInfo, 0)
|
||||
userInfos := make([]*types.UserInfo, 0)
|
||||
|
||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||
if len(queriedUsers) == 0 {
|
||||
@@ -1116,7 +917,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
continue
|
||||
}
|
||||
|
||||
var info *UserInfo
|
||||
var info *types.UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.ToUserInfo(queriedUser, account.Settings)
|
||||
if err != nil {
|
||||
@@ -1136,16 +937,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
}
|
||||
}
|
||||
|
||||
info = &UserInfo{
|
||||
info = &types.UserInfo{
|
||||
ID: localUser.Id,
|
||||
Email: "",
|
||||
Name: name,
|
||||
Role: string(localUser.Role),
|
||||
AutoGroups: localUser.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
Status: string(types.UserStatusActive),
|
||||
IsServiceUser: localUser.IsServiceUser,
|
||||
NonDeletable: localUser.NonDeletable,
|
||||
Permissions: UserPermissions{DashboardView: dashboardViewPermissions},
|
||||
Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions},
|
||||
}
|
||||
}
|
||||
userInfos = append(userInfos, info)
|
||||
@@ -1155,7 +956,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *types.Account, peers []*nbpeer.Peer) error {
|
||||
var peerIDs []string
|
||||
for _, peer := range peers {
|
||||
// nolint:staticcheck
|
||||
@@ -1260,13 +1061,13 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
continue
|
||||
}
|
||||
|
||||
if targetUser.Role == UserRoleOwner {
|
||||
if targetUser.Role == types.UserRoleOwner {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
|
||||
continue
|
||||
}
|
||||
|
||||
// disable deleting integration user if the initiator is not admin service user
|
||||
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
|
||||
if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser {
|
||||
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
|
||||
continue
|
||||
}
|
||||
@@ -1301,7 +1102,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return allErrors
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
@@ -1419,7 +1220,7 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa
|
||||
}
|
||||
|
||||
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account.
|
||||
func areUsersLinkedToPeers(account *Account, userIDs []string) bool {
|
||||
func areUsersLinkedToPeers(account *types.Account, userIDs []string) bool {
|
||||
for _, peer := range account.Peers {
|
||||
if slices.Contains(userIDs, peer.UserID) {
|
||||
return true
|
||||
|
||||
@@ -10,8 +10,12 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -41,11 +45,15 @@ const (
|
||||
)
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -82,14 +90,18 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -104,14 +116,18 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -130,11 +146,15 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -149,11 +169,15 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -168,19 +192,23 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_DeletePAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -204,20 +232,24 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_GetPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -237,13 +269,17 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_GetAllPATs(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
@@ -254,7 +290,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -274,14 +310,14 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
|
||||
func TestUser_Copy(t *testing.T) {
|
||||
// this is an imaginary case which will never be in DB this way
|
||||
user := User{
|
||||
user := types.User{
|
||||
Id: "userId",
|
||||
AccountID: "accountId",
|
||||
Role: "role",
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "servicename",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
Name: "First PAT",
|
||||
@@ -340,11 +376,15 @@ func validateStruct(s interface{}) (err error) {
|
||||
}
|
||||
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -366,26 +406,30 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
assert.NotNil(t, account.Users[user.ID])
|
||||
assert.True(t, account.Users[user.ID].IsServiceUser)
|
||||
assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
|
||||
assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
|
||||
assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
|
||||
assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs)
|
||||
assert.Equal(t, map[string]*types.PersonalAccessToken{}, account.Users[user.ID].PATs)
|
||||
|
||||
assert.Zero(t, user.Email)
|
||||
assert.True(t, user.IsServiceUser)
|
||||
assert.Equal(t, "active", user.Status)
|
||||
|
||||
_, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil)
|
||||
_, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, types.UserRoleOwner, mockServiceUserName, false, nil)
|
||||
if err == nil {
|
||||
t.Fatal("should return error when creating service user with owner role")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -395,7 +439,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
|
||||
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
IsServiceUser: true,
|
||||
@@ -413,7 +457,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
assert.Equal(t, 2, len(account.Users))
|
||||
assert.True(t, account.Users[user.ID].IsServiceUser)
|
||||
assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
|
||||
assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
|
||||
assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role)
|
||||
assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
|
||||
|
||||
assert.Equal(t, mockServiceUserName, user.Name)
|
||||
@@ -423,11 +467,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -437,7 +485,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
|
||||
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
IsServiceUser: false,
|
||||
@@ -448,11 +496,15 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_InviteNewUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -495,7 +547,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
am.idpManager = &idpMock
|
||||
|
||||
// test if new invite with regular role works
|
||||
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
|
||||
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: mockRole,
|
||||
Email: "test@teste.com",
|
||||
@@ -506,9 +558,9 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
assert.NoErrorf(t, err, "Invite user should not throw error")
|
||||
|
||||
// test if new invite with owner role fails
|
||||
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
|
||||
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
|
||||
Name: mockServiceUserName,
|
||||
Role: string(UserRoleOwner),
|
||||
Role: string(types.UserRoleOwner),
|
||||
Email: "test2@teste.com",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
@@ -520,13 +572,13 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serviceUser *User
|
||||
serviceUser *types.User
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrMessage string
|
||||
}{
|
||||
{
|
||||
name: "Can delete service user",
|
||||
serviceUser: &User{
|
||||
serviceUser: &types.User{
|
||||
Id: mockServiceUserID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: mockServiceUserName,
|
||||
@@ -535,7 +587,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Cannot delete non-deletable service user",
|
||||
serviceUser: &User{
|
||||
serviceUser: &types.User{
|
||||
Id: mockServiceUserID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: mockServiceUserName,
|
||||
@@ -548,11 +600,16 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -580,11 +637,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -601,38 +662,42 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
Issued: types.UserIssuedIntegration,
|
||||
}
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
Issued: types.UserIssuedAPI,
|
||||
Role: types.UserRoleOwner,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -683,60 +748,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
Issued: types.UserIssuedIntegration,
|
||||
}
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
account.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
Issued: types.UserIssuedAPI,
|
||||
Role: types.UserRoleOwner,
|
||||
}
|
||||
account.Users["user6"] = &User{
|
||||
account.Users["user6"] = &types.User{
|
||||
Id: "user6",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}
|
||||
account.Users["user7"] = &User{
|
||||
account.Users["user7"] = &types.User{
|
||||
Id: "user7",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}
|
||||
account.Users["user8"] = &User{
|
||||
account.Users["user8"] = &types.User{
|
||||
Id: "user8",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
}
|
||||
account.Users["user9"] = &User{
|
||||
account.Users["user9"] = &types.User{
|
||||
Id: "user9",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -834,11 +903,15 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -863,13 +936,17 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -901,43 +978,43 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
role UserRole
|
||||
role types.UserRole
|
||||
limitedViewSettings bool
|
||||
expectedDashboardPermissions string
|
||||
}{
|
||||
{
|
||||
name: "Regular user, no limited view settings",
|
||||
role: UserRoleUser,
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: false,
|
||||
expectedDashboardPermissions: "limited",
|
||||
},
|
||||
{
|
||||
name: "Admin user, no limited view settings",
|
||||
role: UserRoleAdmin,
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: false,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
{
|
||||
name: "Owner, no limited view settings",
|
||||
role: UserRoleOwner,
|
||||
role: types.UserRoleOwner,
|
||||
limitedViewSettings: false,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
{
|
||||
name: "Regular user, limited view settings",
|
||||
role: UserRoleUser,
|
||||
role: types.UserRoleUser,
|
||||
limitedViewSettings: true,
|
||||
expectedDashboardPermissions: "blocked",
|
||||
},
|
||||
{
|
||||
name: "Admin user, limited view settings",
|
||||
role: UserRoleAdmin,
|
||||
role: types.UserRoleAdmin,
|
||||
limitedViewSettings: true,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
{
|
||||
name: "Owner, limited view settings",
|
||||
role: UserRoleOwner,
|
||||
role: types.UserRoleOwner,
|
||||
limitedViewSettings: true,
|
||||
expectedDashboardPermissions: "full",
|
||||
},
|
||||
@@ -945,13 +1022,18 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
||||
account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI)
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
delete(account.Users, mockUserID)
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -976,13 +1058,17 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
externalUser := &User{
|
||||
externalUser := &types.User{
|
||||
Id: "externalUser",
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedIntegration,
|
||||
Role: types.UserRoleUser,
|
||||
Issued: types.UserIssuedIntegration,
|
||||
IntegrationReference: integration_reference.IntegrationReference{
|
||||
ID: 1,
|
||||
IntegrationType: "external",
|
||||
@@ -990,7 +1076,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
}
|
||||
account.Users[externalUser.Id] = externalUser
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -1020,7 +1106,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(infos))
|
||||
var user *UserInfo
|
||||
var user *types.UserInfo
|
||||
for _, info := range infos {
|
||||
if info.ID == externalUser.Id {
|
||||
user = info
|
||||
@@ -1032,24 +1118,28 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
|
||||
func TestUser_IsAdmin(t *testing.T) {
|
||||
|
||||
user := NewAdminUser(mockUserID)
|
||||
user := types.NewAdminUser(mockUserID)
|
||||
assert.True(t, user.HasAdminPower())
|
||||
|
||||
user = NewRegularUser(mockUserID)
|
||||
user = types.NewRegularUser(mockUserID)
|
||||
assert.False(t, user.HasAdminPower())
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -1068,17 +1158,20 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@@ -1112,25 +1205,25 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
initiatorID string
|
||||
update *User
|
||||
update *types.User
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Should_Fail_To_Update_Admin_Role",
|
||||
expectedErr: true,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleUser,
|
||||
Role: types.UserRoleUser,
|
||||
Blocked: false,
|
||||
},
|
||||
}, {
|
||||
name: "Should_Fail_When_Admin_Blocks_Themselves",
|
||||
expectedErr: true,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
@@ -1138,9 +1231,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Update_Non_Existing_User",
|
||||
expectedErr: true,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: userID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
@@ -1148,9 +1241,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
|
||||
expectedErr: true,
|
||||
initiatorID: regularUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
@@ -1158,9 +1251,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Update_User",
|
||||
expectedErr: false,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
@@ -1168,9 +1261,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Transfer_Owner_Role_To_User",
|
||||
expectedErr: false,
|
||||
initiatorID: ownerUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
@@ -1178,9 +1271,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Transfer_Owner_Role_To_Service_User",
|
||||
expectedErr: true,
|
||||
initiatorID: ownerUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: serviceUserID,
|
||||
Role: UserRoleOwner,
|
||||
Role: types.UserRoleOwner,
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
@@ -1188,9 +1281,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Update_Owner_User_Role_By_Admin",
|
||||
expectedErr: true,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: ownerUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
@@ -1198,9 +1291,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Update_Owner_User_Role_By_User",
|
||||
expectedErr: true,
|
||||
initiatorID: regularUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: ownerUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
@@ -1208,9 +1301,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Update_Owner_User_Role_By_Service_User",
|
||||
expectedErr: true,
|
||||
initiatorID: serviceUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: ownerUserID,
|
||||
Role: UserRoleAdmin,
|
||||
Role: types.UserRoleAdmin,
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
@@ -1218,9 +1311,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Update_Owner_Role_By_Admin",
|
||||
expectedErr: true,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleOwner,
|
||||
Role: types.UserRoleOwner,
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
@@ -1228,9 +1321,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
name: "Should_Fail_To_Block_Owner_Role_By_Admin",
|
||||
expectedErr: true,
|
||||
initiatorID: adminUserID,
|
||||
update: &User{
|
||||
update: &types.User{
|
||||
Id: ownerUserID,
|
||||
Role: UserRoleOwner,
|
||||
Role: types.UserRoleOwner,
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
@@ -1246,9 +1339,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
}
|
||||
|
||||
// create other users
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||
account.Users[adminUserID] = NewAdminUser(adminUserID)
|
||||
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
||||
account.Users[regularUserID] = types.NewRegularUser(regularUserID)
|
||||
account.Users[adminUserID] = types.NewAdminUser(adminUserID)
|
||||
account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1279,15 +1372,15 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := &Policy{
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1307,11 +1400,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleUser,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1330,11 +1423,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleUser,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1364,11 +1457,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// create a user and add new peer with the user
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1390,11 +1483,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: types.UserRoleAdmin,
|
||||
Issued: types.UserIssuedAPI,
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
16
management/server/util/util.go
Normal file
16
management/server/util/util.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package util
|
||||
|
||||
// Difference returns the elements in `a` that aren't in `b`.
|
||||
func Difference(a, b []string) []string {
|
||||
mb := make(map[string]struct{}, len(b))
|
||||
for _, x := range b {
|
||||
mb[x] = struct{}{}
|
||||
}
|
||||
var diff []string
|
||||
for _, x := range a {
|
||||
if _, found := mb[x]; !found {
|
||||
diff = append(diff, x)
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
Reference in New Issue
Block a user