Move client-imported GPL code to separate package

This commit is contained in:
Viktor Liu
2025-10-23 23:18:08 +02:00
parent c20202a6c3
commit 4fd64379da
51 changed files with 369 additions and 345 deletions

View File

@@ -15,27 +15,28 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ ! -z "$RESULTS" ]; then
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo ""
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages will change license and should not be imported by client or shared code"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All license dependencies are clean"
fi

View File

@@ -29,7 +29,7 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestMain(m *testing.M) {

View File

@@ -26,7 +26,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/version"
)
@@ -341,7 +341,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
@@ -364,7 +364,7 @@ func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
return userAuth.UserId != ""
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/netbirdio/netbird/shared/auth"
"math/rand"
"net"
"net/netip"
@@ -1046,7 +1047,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
}
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth,
primaryDomain bool,
) error {
if userAuth.Domain == "" {
@@ -1095,7 +1096,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context,
userAccountID string,
domainAccountID string,
userAuth nbcontext.UserAuth,
userAuth auth.UserAuth,
) error {
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
@@ -1114,7 +1115,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
if userAuth.UserId == "" {
return "", fmt.Errorf("user ID is empty")
}
@@ -1145,7 +1146,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
return newAccount.Id, nil
}
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
@@ -1309,7 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac
return newOnboarding, nil
}
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if userAuth.UserId == "" {
return "", "", errors.New(emptyUserID)
}
@@ -1353,7 +1354,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
if userAuth.IsChild || userAuth.IsPAT {
return nil
}
@@ -1511,7 +1512,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
//
// UserAuth IsChild -> checks that account exists
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
@@ -1590,7 +1591,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
return domainAccountID, cancel, nil
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
@@ -1638,7 +1639,7 @@ func handleNotFound(err error) error {
return nil
}
func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool {
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}

View File

@@ -2,6 +2,7 @@ package account
import (
"context"
"github.com/netbirdio/netbird/shared/auth"
"net"
"net/netip"
"time"
@@ -9,7 +10,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
@@ -45,10 +45,10 @@ type Manager interface {
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
@@ -120,12 +120,12 @@ type Manager interface {
UpdateAccountPeers(ctx context.Context, accountID string)
BufferUpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
GetStore() store.Store
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
}

View File

@@ -25,7 +25,6 @@ import (
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -42,6 +41,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
)
func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) {
@@ -442,7 +442,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams nbcontext.UserAuth
type initUserParams auth.UserAuth
var (
publicDomain = "public.com"
@@ -465,7 +465,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
testCases := []struct {
name string
inputClaims nbcontext.UserAuth
inputClaims auth.UserAuth
inputInitUserParams initUserParams
inputUpdateAttrs bool
inputUpdateClaimAccount bool
@@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}{
{
name: "New User With Public Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: types.PublicCategory,
@@ -497,7 +497,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Unknown Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: types.UnknownCategory,
@@ -514,7 +514,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Private Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -531,7 +531,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New Regular User With Existing Private Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -549,7 +549,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing User With Existing Reclassified Private Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@@ -566,7 +566,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing Account Id With Existing Reclassified Private Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@@ -584,7 +584,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "User With Private Category And Empty Domain",
inputClaims: nbcontext.UserAuth{
inputClaims: auth.UserAuth{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@@ -613,7 +613,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -653,7 +653,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
Domain: domain,
UserId: userId,
@@ -912,13 +912,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
}
publicClaims := nbcontext.UserAuth{
publicClaims := auth.UserAuth{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: types.PublicCategory,
@@ -2648,7 +2648,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group3"},
@@ -2663,7 +2663,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("empty jwt groups", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{},
@@ -2677,7 +2677,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("jwt match existing api group", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1"},
@@ -2698,7 +2698,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1"},
@@ -2716,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add jwt group", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1", "group2"},
@@ -2730,7 +2730,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("existed group not update", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group2"},
@@ -2744,7 +2744,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add new group", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user2",
AccountId: "accountID",
Groups: []string{"group1", "group3"},
@@ -2762,7 +2762,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{},
@@ -2777,7 +2777,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: "user2",
AccountId: "accountID",
Groups: []string{},
@@ -3630,7 +3630,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
// Test adding new user to existing account with approval required
newUserID := "new-user-id"
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
@@ -3660,7 +3660,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
}
// Create a domain-based account without user approval
ownerUserAuth := nbcontext.UserAuth{
ownerUserAuth := auth.UserAuth{
UserId: "owner-user",
Domain: "example.com",
DomainCategory: types.PrivateCategory,
@@ -3679,7 +3679,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
// Test adding new user to existing account without approval required
newUserID := "new-user-id"
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,

View File

@@ -5,22 +5,22 @@ import (
"crypto/sha256"
"encoding/base64"
"fmt"
"github.com/netbirdio/netbird/shared/auth"
"hash/crc32"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/netbird/base62"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
var _ Manager = (*manager)(nil)
type Manager interface {
ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
MarkPATUsed(ctx context.Context, tokenID string) error
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
}
@@ -55,20 +55,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s
}
}
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
token, err := m.validator.ValidateAndParse(ctx, value)
if err != nil {
return nbcontext.UserAuth{}, nil, err
return auth.UserAuth{}, nil, err
}
userAuth, err := m.extractor.ToUserAuth(token)
if err != nil {
return nbcontext.UserAuth{}, nil, err
return auth.UserAuth{}, nil, err
}
return userAuth, token, err
}
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}

View File

@@ -2,10 +2,10 @@ package auth
import (
"context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/golang-jwt/jwt/v5"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -15,18 +15,18 @@ var (
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
type MockManager struct {
ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
}
// EnsureUserAccessByJWTGroups implements Manager.
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
if m.EnsureUserAccessByJWTGroupsFunc != nil {
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
}
return nbcontext.UserAuth{}, nil
return auth.UserAuth{}, nil
}
// GetPATInfo implements Manager.
@@ -46,9 +46,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
}
// ValidateAndParseToken implements Manager.
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
if m.ValidateAndParseTokenFunc != nil {
return m.ValidateAndParseTokenFunc(ctx, value)
}
return nbcontext.UserAuth{}, &jwt.Token{}, nil
return auth.UserAuth{}, &jwt.Token{}, nil
}

View File

@@ -17,10 +17,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbauth "github.com/netbirdio/netbird/shared/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
@@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
}
// this has been validated and parsed by ValidateAndParseToken
userAuth := nbcontext.UserAuth{
userAuth := nbauth.UserAuth{
AccountId: account.Id,
Domain: domain,
UserId: userId,
@@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
tests := []struct {
name string
tokenFunc func() string
expected *nbcontext.UserAuth // nil indicates expected error
expected *nbauth.UserAuth // nil indicates expected error
}{
{
name: "Valid with custom claims",
@@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
tokenString, _ := token.SignedString(key)
return tokenString
},
expected: &nbcontext.UserAuth{
expected: &nbauth.UserAuth{
UserId: "user-id|123",
AccountId: "account-id|567",
Domain: "http://localhost",
@@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
tokenString, _ := token.SignedString(key)
return tokenString
},
expected: &nbcontext.UserAuth{
expected: &nbauth.UserAuth{
UserId: "user-id|123",
},
},

View File

@@ -4,7 +4,8 @@ import (
"context"
"fmt"
"net/http"
"time"
"github.com/netbirdio/netbird/shared/auth"
)
type key int
@@ -13,45 +14,22 @@ const (
UserAuthContextKey key = iota
)
type UserAuth struct {
// The account id the user is accessing
AccountId string
// The account domain
Domain string
// The account domain category, TBC values
DomainCategory string
// Indicates whether this user was invited, TBC logic
Invited bool
// Indicates whether this is a child account
IsChild bool
// The user id
UserId string
// Last login time for this user
LastLogin time.Time
// The Groups the user belongs to on this account
Groups []string
// Indicates whether this user has authenticated with a Personal Access Token
IsPAT bool
}
func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) {
return GetUserAuthFromContext(r.Context())
}
func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request {
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
}
func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) {
if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok {
return userAuth, nil
}
return UserAuth{}, fmt.Errorf("user auth not in context")
return auth.UserAuth{}, fmt.Errorf("user auth not in context")
}
func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context {
//nolint
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
//nolint

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: adminUser.Id,
AccountId: accountID,
Domain: "hotmail.com",

View File

@@ -9,9 +9,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/types"
)
// dnsSettingsHandler is a handler that returns the DNS settings of the account

View File

@@ -11,13 +11,14 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
AccountId: testingDNSSettingsAccount.Id,
Domain: testingDNSSettingsAccount.Domain,

View File

@@ -19,6 +19,7 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
AccountId: testNSGroupAccountID,
Domain: "hotmail.com",

View File

@@ -14,11 +14,12 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func initEventsTestData(account string, events ...*activity.Event) *handler {
@@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",

View File

@@ -11,10 +11,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns groups of the account

View File

@@ -19,12 +19,13 @@ import (
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
var TestPeers = map[string]*nbpeer.Peer{
@@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -12,15 +12,15 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/shared/management/status"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// handler is a handler that returns networks of the account

View File

@@ -8,10 +8,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
type resourceHandler struct {

View File

@@ -7,10 +7,10 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
type routersHandler struct {

View File

@@ -17,9 +17,10 @@ import (
"golang.org/x/exp/maps"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -277,7 +278,7 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "admin_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -425,7 +426,7 @@ func TestGetAccessiblePeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
@@ -508,7 +509,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -16,12 +16,13 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/util"
)
@@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -9,11 +9,11 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)

View File

@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns policy of the account

View File

@@ -14,10 +14,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
func initPoliciesTestData(policies ...*types.Policy) *handler {
@@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -9,9 +9,9 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
)

View File

@@ -16,9 +16,10 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
@@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,

View File

@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns a list of setup keys of the account

View File

@@ -15,10 +15,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: adminUser.Id,
Domain: "hotmail.com",
AccountId: "testAccountId",

View File

@@ -8,10 +8,10 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// patHandler is the nameserver group handler of the account

View File

@@ -17,10 +17,11 @@ import (
"github.com/netbirdio/netbird/management/server/util"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -128,7 +129,7 @@ func initUsersTestData() *handler {
return nil
},
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
switch userAuth.UserId {
case "not-found":
return nil, status.NewUserNotFoundError("not-found")
@@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder()
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
@@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) {
tt := []struct {
name string
expectedStatus int
requestAuth nbcontext.UserAuth
requestAuth auth.UserAuth
expectedResult *api.User
}{
{
@@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "user not found",
requestAuth: nbcontext.UserAuth{UserId: "not-found"},
requestAuth: auth.UserAuth{UserId: "not-found"},
expectedStatus: http.StatusNotFound,
},
{
name: "not of account",
requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
requestAuth: auth.UserAuth{UserId: "not-of-account"},
expectedStatus: http.StatusForbidden,
},
{
name: "blocked user",
requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
requestAuth: auth.UserAuth{UserId: "blocked-user"},
expectedStatus: http.StatusForbidden,
},
{
name: "service user",
requestAuth: nbcontext.UserAuth{UserId: "service-user"},
requestAuth: auth.UserAuth{UserId: "service-user"},
expectedStatus: http.StatusForbidden,
},
{
name: "owner",
requestAuth: nbcontext.UserAuth{UserId: "owner"},
requestAuth: auth.UserAuth{UserId: "owner"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "owner",
@@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "regular user",
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
requestAuth: auth.UserAuth{UserId: "regular-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "regular-user",
@@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "admin user",
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
requestAuth: auth.UserAuth{UserId: "admin-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "admin-user",
@@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) {
},
{
name: "restricted user",
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
requestAuth: auth.UserAuth{UserId: "restricted-user"},
expectedStatus: http.StatusOK,
expectedResult: &api.User{
Id: "restricted-user",
@@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) {
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
@@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) {
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}

View File

@@ -10,22 +10,23 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/auth"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
authManager auth.Manager
authManager serverauth.Manager
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
@@ -33,7 +34,7 @@ type AuthMiddleware struct {
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(
authManager auth.Manager,
authManager serverauth.Manager,
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
@@ -53,18 +54,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return
}
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0])
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(authHeader[0])
// fallback to token when receive pat as bearer
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
authType = "token"
auth[0] = authType
authHeader[0] = authType
}
switch authType {
case "bearer":
request, err := m.checkJWTFromRequest(r, auth)
request, err := m.checkJWTFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
@@ -73,7 +74,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, request)
case "token":
request, err := m.checkPATFromRequest(r, auth)
request, err := m.checkPATFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
@@ -88,8 +89,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(auth)
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error
if err != nil {
@@ -139,8 +140,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(auth)
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return r, fmt.Errorf("error extracting token: %w", err)
}
@@ -159,7 +160,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, err
}
userAuth := nbcontext.UserAuth{
userAuth := auth.UserAuth{
UserId: user.Id,
AccountId: user.AccountID,
Domain: accDomain,

View File

@@ -12,11 +12,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbauth "github.com/netbirdio/netbird/shared/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -61,9 +62,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
if token == JWT {
return nbcontext.UserAuth{
return nbauth.UserAuth{
UserId: userID,
AccountId: accountID,
Domain: testAccount.Domain,
@@ -77,7 +78,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
Valid: true,
}, nil
}
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
@@ -87,7 +88,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
return fmt.Errorf("Should never get reached")
}
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
@@ -183,13 +184,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
)
@@ -226,13 +227,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name string
path string
authHeader string
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
}{
{
name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -244,7 +245,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid PAT Token accesses child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -257,7 +258,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
@@ -269,7 +270,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
name: "Valid JWT Token with child",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
@@ -288,13 +289,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
)

View File

@@ -13,8 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
serverauth "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
@@ -28,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
)
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
@@ -68,8 +68,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &auth.MockManager{
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
@@ -114,8 +114,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.Up
}
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
userAuth := auth.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":

View File

@@ -1,138 +1,137 @@
package idp
import (
"context"
"testing"
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewPocketIdManager(t *testing.T) {
type test struct {
name string
inputConfig PocketIdClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
type test struct {
name string
inputConfig PocketIdClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := PocketIdClientConfig{
APIToken: "api_token",
ManagementEndpoint: "http://localhost",
}
defaultTestConfig := PocketIdClientConfig{
APIToken: "api_token",
ManagementEndpoint: "http://localhost",
}
tests := []test{
{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
},
{
name: "Missing ManagementEndpoint",
inputConfig: PocketIdClientConfig{
APIToken: defaultTestConfig.APIToken,
ManagementEndpoint: "",
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
{
name: "Missing APIToken",
inputConfig: PocketIdClientConfig{
APIToken: "",
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
}
tests := []test{
{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
},
{
name: "Missing ManagementEndpoint",
inputConfig: PocketIdClientConfig{
APIToken: defaultTestConfig.APIToken,
ManagementEndpoint: "",
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
{
name: "Missing APIToken",
inputConfig: PocketIdClientConfig{
APIToken: "",
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
})
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
})
}
}
func TestPocketID_GetUserDataByID(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
md := AppMetadata{WTAccountID: "acc1"}
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
require.NoError(t, err)
assert.Equal(t, "u1", got.ID)
assert.Equal(t, "user1@example.com", got.Email)
assert.Equal(t, "User One", got.Name)
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
md := AppMetadata{WTAccountID: "acc1"}
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
require.NoError(t, err)
assert.Equal(t, "u1", got.ID)
assert.Equal(t, "user1@example.com", got.Email)
assert.Equal(t, "User One", got.Name)
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
}
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
// Single page response with two users
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
// Single page response with two users
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
users, err := mgr.GetAccount(context.Background(), "accX")
require.NoError(t, err)
require.Len(t, users, 2)
assert.Equal(t, "u1", users[0].ID)
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "u2", users[1].ID)
users, err := mgr.GetAccount(context.Background(), "accX")
require.NoError(t, err)
require.Len(t, users, 2)
assert.Equal(t, "u1", users[0].ID)
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "u2", users[1].ID)
}
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
accounts, err := mgr.GetAllAccounts(context.Background())
require.NoError(t, err)
require.Len(t, accounts[UnsetAccountID], 2)
accounts, err := mgr.GetAllAccounts(context.Background())
require.NoError(t, err)
require.Len(t, accounts[UnsetAccountID], 2)
}
func TestPocketID_CreateUser(t *testing.T) {
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
require.NoError(t, err)
assert.Equal(t, "newid", ud.ID)
assert.Equal(t, "new@example.com", ud.Email)
assert.Equal(t, "New User", ud.Name)
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
assert.True(t, *ud.AppMetadata.WTPendingInvite)
}
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
require.NoError(t, err)
assert.Equal(t, "newid", ud.ID)
assert.Equal(t, "new@example.com", ud.Email)
assert.Equal(t, "New User", ud.Name)
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
assert.True(t, *ud.AppMetadata.WTPendingInvite)
}
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
}
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
// Same mock for both calls; returns OK with empty JSON
client := &mockHTTPClient{code: 200, resBody: `{}`}
// Same mock for both calls; returns OK with empty JSON
client := &mockHTTPClient{code: 200, resBody: `{}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
err = mgr.InviteUserByID(context.Background(), "u1")
require.NoError(t, err)
err = mgr.InviteUserByID(context.Background(), "u1")
require.NoError(t, err)
err = mgr.DeleteUser(context.Background(), "u1")
require.NoError(t, err)
err = mgr.DeleteUser(context.Background(), "u1")
require.NoError(t, err)
}

View File

@@ -2,6 +2,7 @@ package mock_server
import (
"context"
"github.com/netbirdio/netbird/shared/auth"
"net"
"net/netip"
"time"
@@ -12,7 +13,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
@@ -34,7 +34,7 @@ type MockAccountManager struct {
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
@@ -84,7 +84,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func(settings *types.Settings) string
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
@@ -119,7 +119,7 @@ type MockAccountManager struct {
GetStoreFunc func() store.Store
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
@@ -469,7 +469,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
}
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
if am.GetUserFromUserAuthFunc != nil {
return am.GetUserFromUserAuthFunc(ctx, userAuth)
}
@@ -674,7 +674,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if am.GetAccountIDFromUserAuthFunc != nil {
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
}
@@ -936,7 +936,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
}
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
}
@@ -968,7 +968,7 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
}
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
if am.GetCurrentUserInfoFunc != nil {
return am.GetCurrentUserInfoFunc(ctx, userAuth)
}

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {

View File

@@ -8,11 +8,11 @@ import (
"github.com/rs/xid"
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
)

View File

@@ -9,8 +9,8 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {

View File

@@ -5,8 +5,8 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type NetworkRouter struct {

View File

@@ -7,8 +7,8 @@ import (
"regexp"
"github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)

View File

@@ -1,8 +1,8 @@
package types
import (
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// RouteFirewallRule a firewall rule applicable for a routed network.

View File

@@ -7,9 +7,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
const channelBufferSize = 100

View File

@@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"strings"
"time"
@@ -11,8 +13,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -175,9 +175,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
}
// GetUser looks up a user by provided nbContext.UserAuths.
// GetUser looks up a user by provided auth.UserAuths.
// Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
return nil, err
@@ -940,7 +940,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
var peerIDs []string
for _, peer := range peers {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key)
if peer.UserID == "" {
// we do not want to expire peers that are added via setup key
@@ -1171,7 +1171,7 @@ func validateUserInvite(invite *types.UserInfo) error {
}
// GetCurrentUserInfo retrieves the account's current user info and permissions
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)

View File

@@ -11,12 +11,12 @@ import (
"golang.org/x/exp/maps"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -966,7 +966,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
permissionsManager: permissionsManager,
}
claims := nbcontext.UserAuth{
claims := auth.UserAuth{
UserId: mockUserID,
AccountId: mockAccountID,
}
@@ -1573,33 +1573,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
tt := []struct {
name string
userAuth nbcontext.UserAuth
userAuth auth.UserAuth
expectedErr error
expectedResult *users.UserInfoWithPermissions
}{
{
name: "not found",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"},
expectedErr: status.NewUserNotFoundError("not-found"),
},
{
name: "not part of account",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
expectedErr: status.NewUserNotPartOfAccountError(),
},
{
name: "blocked",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
expectedErr: status.NewUserBlockedError(),
},
{
name: "service user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"},
expectedErr: status.NewPermissionDeniedError(),
},
{
name: "owner user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "account1Owner",
@@ -1619,7 +1619,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "regular user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "regular-user",
@@ -1638,7 +1638,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "admin user",
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "admin-user",
@@ -1657,7 +1657,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "settings blocked regular user",
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "settings-blocked-user",
@@ -1678,7 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
{
name: "settings blocked regular user child account",
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "settings-blocked-user",
@@ -1698,7 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
},
{
name: "settings blocked owner user",
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
expectedResult: &users.UserInfoWithPermissions{
UserInfo: &types.UserInfo{
ID: "account2Owner",

View File

@@ -8,7 +8,7 @@ import (
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
)
const (
@@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string {
return url
}
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
// ToUserAuth extracts user authentication information from a JWT token
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
claims := token.Claims.(jwt.MapClaims)
userAuth := nbcontext.UserAuth{}
userAuth := auth.UserAuth{}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
@@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro
return userAuth, nil
}
// ToGroups extracts group information from a JWT token
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
claims := token.Claims.(jwt.MapClaims)
userJWTGroups := make([]string, 0)

28
shared/auth/user.go Normal file
View File

@@ -0,0 +1,28 @@
package auth
import (
"time"
)
type UserAuth struct {
// The account id the user is accessing
AccountId string
// The account domain
Domain string
// The account domain category, TBC values
DomainCategory string
// Indicates whether this user was invited, TBC logic
Invited bool
// Indicates whether this is a child account
IsChild bool
// The user id
UserId string
// Last login time for this user
LastLogin time.Time
// The Groups the user belongs to on this account
Groups []string
// Indicates whether this user has authenticated with a Personal Access Token
IsPAT bool
}