mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Move client-imported GPL code to separate package
This commit is contained in:
15
.github/workflows/check-license-dependencies.yml
vendored
15
.github/workflows/check-license-dependencies.yml
vendored
@@ -15,27 +15,28 @@ jobs:
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ ! -z "$RESULTS" ]; then
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo ""
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages will change license and should not be imported by client or shared code"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All license dependencies are clean"
|
||||
fi
|
||||
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -341,7 +341,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
s.mu.RLock()
|
||||
jwtExtractor := s.jwtExtractor
|
||||
s.mu.RUnlock()
|
||||
@@ -364,7 +364,7 @@ func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth
|
||||
return &userAuth, nil
|
||||
}
|
||||
|
||||
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
|
||||
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
return userAuth.UserId != ""
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -1046,7 +1047,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
if userAuth.Domain == "" {
|
||||
@@ -1095,7 +1096,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
ctx context.Context,
|
||||
userAccountID string,
|
||||
domainAccountID string,
|
||||
userAuth nbcontext.UserAuth,
|
||||
userAuth auth.UserAuth,
|
||||
) error {
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
|
||||
@@ -1114,7 +1115,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
@@ -1145,7 +1146,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
return newAccount.Id, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
|
||||
@@ -1309,7 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac
|
||||
return newOnboarding, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
@@ -1353,7 +1354,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return nil
|
||||
}
|
||||
@@ -1511,7 +1512,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
//
|
||||
// UserAuth IsChild -> checks that account exists
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
|
||||
|
||||
@@ -1590,7 +1591,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
return domainAccountID, cancel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
@@ -1638,7 +1639,7 @@ func handleNotFound(err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
|
||||
func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool {
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
@@ -9,7 +10,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
@@ -45,10 +45,10 @@ type Manager interface {
|
||||
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||
@@ -120,12 +120,12 @@ type Manager interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
GetStore() store.Store
|
||||
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
@@ -42,6 +41,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) {
|
||||
@@ -442,7 +442,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
type initUserParams nbcontext.UserAuth
|
||||
type initUserParams auth.UserAuth
|
||||
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
@@ -465,7 +465,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputClaims nbcontext.UserAuth
|
||||
inputClaims auth.UserAuth
|
||||
inputInitUserParams initUserParams
|
||||
inputUpdateAttrs bool
|
||||
inputUpdateClaimAccount bool
|
||||
@@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: types.PublicCategory,
|
||||
@@ -497,7 +497,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: types.UnknownCategory,
|
||||
@@ -514,7 +514,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -531,7 +531,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -549,7 +549,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -566,7 +566,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -584,7 +584,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "User With Private Category And Empty Domain",
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
inputClaims: auth.UserAuth{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -613,7 +613,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
@@ -653,7 +653,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
@@ -912,13 +912,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
Domain: "example.com",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
}
|
||||
|
||||
publicClaims := nbcontext.UserAuth{
|
||||
publicClaims := auth.UserAuth{
|
||||
Domain: "test.com",
|
||||
UserId: "public-domain-user",
|
||||
DomainCategory: types.PublicCategory,
|
||||
@@ -2648,7 +2648,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
||||
|
||||
t.Run("skip sync for token auth type", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group3"},
|
||||
@@ -2663,7 +2663,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty jwt groups", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{},
|
||||
@@ -2677,7 +2677,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("jwt match existing api group", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1"},
|
||||
@@ -2698,7 +2698,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
|
||||
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1"},
|
||||
@@ -2716,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add jwt group", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1", "group2"},
|
||||
@@ -2730,7 +2730,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("existed group not update", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group2"},
|
||||
@@ -2744,7 +2744,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add new group", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user2",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1", "group3"},
|
||||
@@ -2762,7 +2762,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{},
|
||||
@@ -2777,7 +2777,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: "user2",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{},
|
||||
@@ -3630,7 +3630,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
|
||||
// Test adding new user to existing account with approval required
|
||||
newUserID := "new-user-id"
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: newUserID,
|
||||
Domain: "example.com",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -3660,7 +3660,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a domain-based account without user approval
|
||||
ownerUserAuth := nbcontext.UserAuth{
|
||||
ownerUserAuth := auth.UserAuth{
|
||||
UserId: "owner-user",
|
||||
Domain: "example.com",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@@ -3679,7 +3679,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
|
||||
// Test adding new user to existing account without approval required
|
||||
newUserID := "new-user-id"
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: newUserID,
|
||||
Domain: "example.com",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
|
||||
@@ -5,22 +5,22 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"hash/crc32"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
var _ Manager = (*manager)(nil)
|
||||
|
||||
type Manager interface {
|
||||
ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
|
||||
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
|
||||
ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
|
||||
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
}
|
||||
@@ -55,20 +55,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
|
||||
token, err := m.validator.ValidateAndParse(ctx, value)
|
||||
if err != nil {
|
||||
return nbcontext.UserAuth{}, nil, err
|
||||
return auth.UserAuth{}, nil, err
|
||||
}
|
||||
|
||||
userAuth, err := m.extractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
return nbcontext.UserAuth{}, nil, err
|
||||
return auth.UserAuth{}, nil, err
|
||||
}
|
||||
return userAuth, token, err
|
||||
}
|
||||
|
||||
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -15,18 +15,18 @@ var (
|
||||
|
||||
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
|
||||
type MockManager struct {
|
||||
ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
|
||||
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
|
||||
ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
|
||||
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
|
||||
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
}
|
||||
|
||||
// EnsureUserAccessByJWTGroups implements Manager.
|
||||
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
|
||||
if m.EnsureUserAccessByJWTGroupsFunc != nil {
|
||||
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
|
||||
}
|
||||
return nbcontext.UserAuth{}, nil
|
||||
return auth.UserAuth{}, nil
|
||||
}
|
||||
|
||||
// GetPATInfo implements Manager.
|
||||
@@ -46,9 +46,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
}
|
||||
|
||||
// ValidateAndParseToken implements Manager.
|
||||
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
|
||||
if m.ValidateAndParseTokenFunc != nil {
|
||||
return m.ValidateAndParseTokenFunc(ctx, value)
|
||||
}
|
||||
return nbcontext.UserAuth{}, &jwt.Token{}, nil
|
||||
return auth.UserAuth{}, &jwt.Token{}, nil
|
||||
}
|
||||
|
||||
@@ -17,10 +17,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbauth "github.com/netbirdio/netbird/shared/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||
@@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
// this has been validated and parsed by ValidateAndParseToken
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := nbauth.UserAuth{
|
||||
AccountId: account.Id,
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
@@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenFunc func() string
|
||||
expected *nbcontext.UserAuth // nil indicates expected error
|
||||
expected *nbauth.UserAuth // nil indicates expected error
|
||||
}{
|
||||
{
|
||||
name: "Valid with custom claims",
|
||||
@@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
expected: &nbcontext.UserAuth{
|
||||
expected: &nbauth.UserAuth{
|
||||
UserId: "user-id|123",
|
||||
AccountId: "account-id|567",
|
||||
Domain: "http://localhost",
|
||||
@@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
expected: &nbcontext.UserAuth{
|
||||
expected: &nbauth.UserAuth{
|
||||
UserId: "user-id|123",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -4,7 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
type key int
|
||||
@@ -13,45 +14,22 @@ const (
|
||||
UserAuthContextKey key = iota
|
||||
)
|
||||
|
||||
type UserAuth struct {
|
||||
// The account id the user is accessing
|
||||
AccountId string
|
||||
// The account domain
|
||||
Domain string
|
||||
// The account domain category, TBC values
|
||||
DomainCategory string
|
||||
// Indicates whether this user was invited, TBC logic
|
||||
Invited bool
|
||||
// Indicates whether this is a child account
|
||||
IsChild bool
|
||||
|
||||
// The user id
|
||||
UserId string
|
||||
// Last login time for this user
|
||||
LastLogin time.Time
|
||||
// The Groups the user belongs to on this account
|
||||
Groups []string
|
||||
|
||||
// Indicates whether this user has authenticated with a Personal Access Token
|
||||
IsPAT bool
|
||||
}
|
||||
|
||||
func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
|
||||
func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) {
|
||||
return GetUserAuthFromContext(r.Context())
|
||||
}
|
||||
|
||||
func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
|
||||
func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request {
|
||||
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
|
||||
}
|
||||
|
||||
func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
|
||||
if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
|
||||
func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) {
|
||||
if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok {
|
||||
return userAuth, nil
|
||||
}
|
||||
return UserAuth{}, fmt.Errorf("user auth not in context")
|
||||
return auth.UserAuth{}, fmt.Errorf("user auth not in context")
|
||||
}
|
||||
|
||||
func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
|
||||
func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
|
||||
//nolint
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: adminUser.Id,
|
||||
AccountId: accountID,
|
||||
Domain: "hotmail.com",
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
|
||||
AccountId: testingDNSSettingsAccount.Id,
|
||||
Domain: testingDNSSettingsAccount.Domain,
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
AccountId: testNSGroupAccountID,
|
||||
Domain: "hotmail.com",
|
||||
|
||||
@@ -14,11 +14,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
@@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_account",
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns groups of the account
|
||||
|
||||
@@ -19,12 +19,13 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
var TestPeers = map[string]*nbpeer.Peer{
|
||||
@@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -12,15 +12,15 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// handler is a handler that returns networks of the account
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
type resourceHandler struct {
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
type routersHandler struct {
|
||||
|
||||
@@ -17,9 +17,10 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -277,7 +278,7 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "admin_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -425,7 +426,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: tc.callerUserID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -508,7 +509,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: tc.callerUserID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -16,12 +16,13 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns policy of the account
|
||||
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||
@@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
|
||||
@@ -16,9 +16,10 @@ import (
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
@@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns a list of setup keys of the account
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: adminUser.Id,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "testAccountId",
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// patHandler is the nameserver group handler of the account
|
||||
|
||||
@@ -17,10 +17,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -128,7 +129,7 @@ func initUsersTestData() *handler {
|
||||
|
||||
return nil
|
||||
},
|
||||
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
switch userAuth.UserId {
|
||||
case "not-found":
|
||||
return nil, status.NewUserNotFoundError("not-found")
|
||||
@@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
@@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestAuth nbcontext.UserAuth
|
||||
requestAuth auth.UserAuth
|
||||
expectedResult *api.User
|
||||
}{
|
||||
{
|
||||
@@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "user not found",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "not-found"},
|
||||
requestAuth: auth.UserAuth{UserId: "not-found"},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "not of account",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
|
||||
requestAuth: auth.UserAuth{UserId: "not-of-account"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "blocked user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "blocked-user"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "service user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "service-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "service-user"},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "owner",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "owner"},
|
||||
requestAuth: auth.UserAuth{UserId: "owner"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "owner",
|
||||
@@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "regular user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "regular-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "regular-user",
|
||||
@@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "admin user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "admin-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "admin-user",
|
||||
@@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "restricted user",
|
||||
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
|
||||
requestAuth: auth.UserAuth{UserId: "restricted-user"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedResult: &api.User{
|
||||
Id: "restricted-user",
|
||||
@@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
@@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) {
|
||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
|
||||
@@ -10,22 +10,23 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager auth.Manager
|
||||
authManager serverauth.Manager
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
@@ -33,7 +34,7 @@ type AuthMiddleware struct {
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(
|
||||
authManager auth.Manager,
|
||||
authManager serverauth.Manager,
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
@@ -53,18 +54,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(auth[0])
|
||||
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := strings.ToLower(authHeader[0])
|
||||
|
||||
// fallback to token when receive pat as bearer
|
||||
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
|
||||
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
|
||||
authType = "token"
|
||||
auth[0] = authType
|
||||
authHeader[0] = authType
|
||||
}
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, auth)
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -73,7 +74,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
@@ -88,8 +89,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
@@ -139,8 +140,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
@@ -159,7 +160,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, err
|
||||
}
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: user.Id,
|
||||
AccountId: user.AccountID,
|
||||
Domain: accDomain,
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
nbauth "github.com/netbirdio/netbird/shared/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -61,9 +62,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return nbcontext.UserAuth{
|
||||
return nbauth.UserAuth{
|
||||
UserId: userID,
|
||||
AccountId: accountID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -77,7 +78,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
@@ -87,7 +88,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
@@ -183,13 +184,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
)
|
||||
@@ -226,13 +227,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name string
|
||||
path string
|
||||
authHeader string
|
||||
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
|
||||
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
|
||||
}{
|
||||
{
|
||||
name: "Valid PAT Token",
|
||||
path: "/test",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -244,7 +245,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid PAT Token accesses child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -257,7 +258,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid JWT Token",
|
||||
path: "/test",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -269,7 +270,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
name: "Valid JWT Token with child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
expectedUserAuth: &nbauth.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
@@ -288,13 +289,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
@@ -13,8 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||
@@ -68,8 +68,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
}
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &auth.MockManager{
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &serverauth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: authManager.MarkPATUsed,
|
||||
@@ -114,8 +114,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.Up
|
||||
}
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
|
||||
userAuth := auth.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
|
||||
@@ -1,138 +1,137 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
|
||||
func TestNewPocketIdManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig PocketIdClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig PocketIdClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := PocketIdClientConfig{
|
||||
APIToken: "api_token",
|
||||
ManagementEndpoint: "http://localhost",
|
||||
}
|
||||
defaultTestConfig := PocketIdClientConfig{
|
||||
APIToken: "api_token",
|
||||
ManagementEndpoint: "http://localhost",
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
},
|
||||
{
|
||||
name: "Missing ManagementEndpoint",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: defaultTestConfig.APIToken,
|
||||
ManagementEndpoint: "",
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
{
|
||||
name: "Missing APIToken",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: "",
|
||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
}
|
||||
tests := []test{
|
||||
{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
},
|
||||
{
|
||||
name: "Missing ManagementEndpoint",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: defaultTestConfig.APIToken,
|
||||
ManagementEndpoint: "",
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
{
|
||||
name: "Missing APIToken",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: "",
|
||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
md := AppMetadata{WTAccountID: "acc1"}
|
||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "u1", got.ID)
|
||||
assert.Equal(t, "user1@example.com", got.Email)
|
||||
assert.Equal(t, "User One", got.Name)
|
||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||
md := AppMetadata{WTAccountID: "acc1"}
|
||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "u1", got.ID)
|
||||
assert.Equal(t, "user1@example.com", got.Email)
|
||||
assert.Equal(t, "User One", got.Name)
|
||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
||||
// Single page response with two users
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
// Single page response with two users
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
assert.Equal(t, "u1", users[0].ID)
|
||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "u2", users[1].ID)
|
||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
assert.Equal(t, "u1", users[0].ID)
|
||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "u2", users[1].ID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accounts[UnsetAccountID], 2)
|
||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accounts[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestPocketID_CreateUser(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newid", ud.ID)
|
||||
assert.Equal(t, "new@example.com", ud.Email)
|
||||
assert.Equal(t, "New User", ud.Name)
|
||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||
}
|
||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newid", ud.ID)
|
||||
assert.Equal(t, "new@example.com", ud.Email)
|
||||
assert.Equal(t, "New User", ud.Name)
|
||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||
}
|
||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||
}
|
||||
|
||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
||||
// Same mock for both calls; returns OK with empty JSON
|
||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||
// Same mock for both calls; returns OK with empty JSON
|
||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.DeleteUser(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
err = mgr.DeleteUser(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package mock_server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
@@ -34,7 +34,7 @@ type MockAccountManager struct {
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
@@ -84,7 +84,7 @@ type MockAccountManager struct {
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func(settings *types.Settings) string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
@@ -119,7 +119,7 @@ type MockAccountManager struct {
|
||||
GetStoreFunc func() store.Store
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
@@ -469,7 +469,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
|
||||
}
|
||||
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
|
||||
if am.GetUserFromUserAuthFunc != nil {
|
||||
return am.GetUserFromUserAuthFunc(ctx, userAuth)
|
||||
}
|
||||
@@ -674,7 +674,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||
if am.GetAccountIDFromUserAuthFunc != nil {
|
||||
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
||||
}
|
||||
@@ -936,7 +936,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
|
||||
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
|
||||
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
|
||||
}
|
||||
|
||||
@@ -968,7 +968,7 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
if am.GetCurrentUserInfoFunc != nil {
|
||||
return am.GetCurrentUserInfoFunc(ctx, userAuth)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
type NetworkRouter struct {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// RouteFirewallRule a firewall rule applicable for a routed network.
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -11,8 +13,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -175,9 +175,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
|
||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided nbContext.UserAuths.
|
||||
// GetUser looks up a user by provided auth.UserAuths.
|
||||
// Expects account to have been created already.
|
||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -940,7 +940,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
var peerIDs []string
|
||||
for _, peer := range peers {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
|
||||
ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key)
|
||||
|
||||
if peer.UserID == "" {
|
||||
// we do not want to expire peers that are added via setup key
|
||||
@@ -1171,7 +1171,7 @@ func validateUserInvite(invite *types.UserInfo) error {
|
||||
}
|
||||
|
||||
// GetCurrentUserInfo retrieves the account's current user info and permissions
|
||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -966,7 +966,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
|
||||
claims := nbcontext.UserAuth{
|
||||
claims := auth.UserAuth{
|
||||
UserId: mockUserID,
|
||||
AccountId: mockAccountID,
|
||||
}
|
||||
@@ -1573,33 +1573,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
userAuth nbcontext.UserAuth
|
||||
userAuth auth.UserAuth
|
||||
expectedErr error
|
||||
expectedResult *users.UserInfoWithPermissions
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"},
|
||||
expectedErr: status.NewUserNotFoundError("not-found"),
|
||||
},
|
||||
{
|
||||
name: "not part of account",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
|
||||
expectedErr: status.NewUserNotPartOfAccountError(),
|
||||
},
|
||||
{
|
||||
name: "blocked",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
|
||||
expectedErr: status.NewUserBlockedError(),
|
||||
},
|
||||
{
|
||||
name: "service user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"},
|
||||
expectedErr: status.NewPermissionDeniedError(),
|
||||
},
|
||||
{
|
||||
name: "owner user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "account1Owner",
|
||||
@@ -1619,7 +1619,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "regular user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "regular-user",
|
||||
@@ -1638,7 +1638,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "admin user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
|
||||
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "admin-user",
|
||||
@@ -1657,7 +1657,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "settings blocked regular user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
|
||||
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "settings-blocked-user",
|
||||
@@ -1678,7 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
|
||||
{
|
||||
name: "settings blocked regular user child account",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
|
||||
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "settings-blocked-user",
|
||||
@@ -1698,7 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "settings blocked owner user",
|
||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
|
||||
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
|
||||
expectedResult: &users.UserInfoWithPermissions{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "account2Owner",
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string {
|
||||
return url
|
||||
}
|
||||
|
||||
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
// ToUserAuth extracts user authentication information from a JWT token
|
||||
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
userAuth := auth.UserAuth{}
|
||||
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
@@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
// ToGroups extracts group information from a JWT token
|
||||
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userJWTGroups := make([]string, 0)
|
||||
28
shared/auth/user.go
Normal file
28
shared/auth/user.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserAuth struct {
|
||||
// The account id the user is accessing
|
||||
AccountId string
|
||||
// The account domain
|
||||
Domain string
|
||||
// The account domain category, TBC values
|
||||
DomainCategory string
|
||||
// Indicates whether this user was invited, TBC logic
|
||||
Invited bool
|
||||
// Indicates whether this is a child account
|
||||
IsChild bool
|
||||
|
||||
// The user id
|
||||
UserId string
|
||||
// Last login time for this user
|
||||
LastLogin time.Time
|
||||
// The Groups the user belongs to on this account
|
||||
Groups []string
|
||||
|
||||
// Indicates whether this user has authenticated with a Personal Access Token
|
||||
IsPAT bool
|
||||
}
|
||||
Reference in New Issue
Block a user