mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Move client-imported GPL code to separate package
This commit is contained in:
15
.github/workflows/check-license-dependencies.yml
vendored
15
.github/workflows/check-license-dependencies.yml
vendored
@@ -15,27 +15,28 @@ jobs:
|
|||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
FOUND_ISSUES=0
|
FOUND_ISSUES=0
|
||||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||||
if [ ! -z "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
FOUND_ISSUES=1
|
FOUND_ISSUES=1
|
||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||||
|
|
||||||
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo ""
|
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||||
echo "These packages will change license and should not be imported by client or shared code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
echo ""
|
|
||||||
echo "✅ All license dependencies are clean"
|
echo "✅ All license dependencies are clean"
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import (
|
|||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/ssh/server"
|
"github.com/netbirdio/netbird/client/ssh/server"
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh/client"
|
"github.com/netbirdio/netbird/client/ssh/client"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTEnforcement(t *testing.T) {
|
func TestJWTEnforcement(t *testing.T) {
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/management/server/auth/jwt"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -341,7 +341,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
|
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
jwtExtractor := s.jwtExtractor
|
jwtExtractor := s.jwtExtractor
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
@@ -364,7 +364,7 @@ func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth
|
|||||||
return &userAuth, nil
|
return &userAuth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
|
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||||
return userAuth.UserId != ""
|
return userAuth.UserId != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -1046,7 +1047,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
|
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth,
|
||||||
primaryDomain bool,
|
primaryDomain bool,
|
||||||
) error {
|
) error {
|
||||||
if userAuth.Domain == "" {
|
if userAuth.Domain == "" {
|
||||||
@@ -1095,7 +1096,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userAccountID string,
|
userAccountID string,
|
||||||
domainAccountID string,
|
domainAccountID string,
|
||||||
userAuth nbcontext.UserAuth,
|
userAuth auth.UserAuth,
|
||||||
) error {
|
) error {
|
||||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
|
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
|
||||||
@@ -1114,7 +1115,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
|||||||
|
|
||||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||||
// otherwise it will create a new account and make it primary account for the domain.
|
// otherwise it will create a new account and make it primary account for the domain.
|
||||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
|
||||||
if userAuth.UserId == "" {
|
if userAuth.UserId == "" {
|
||||||
return "", fmt.Errorf("user ID is empty")
|
return "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
@@ -1145,7 +1146,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
|||||||
return newAccount.Id, nil
|
return newAccount.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
|
||||||
newUser := types.NewRegularUser(userAuth.UserId)
|
newUser := types.NewRegularUser(userAuth.UserId)
|
||||||
newUser.AccountID = domainAccountID
|
newUser.AccountID = domainAccountID
|
||||||
|
|
||||||
@@ -1309,7 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac
|
|||||||
return newOnboarding, nil
|
return newOnboarding, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||||
if userAuth.UserId == "" {
|
if userAuth.UserId == "" {
|
||||||
return "", "", errors.New(emptyUserID)
|
return "", "", errors.New(emptyUserID)
|
||||||
}
|
}
|
||||||
@@ -1353,7 +1354,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
// and propagates changes to peers if group propagation is enabled.
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
|
||||||
if userAuth.IsChild || userAuth.IsPAT {
|
if userAuth.IsChild || userAuth.IsPAT {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1511,7 +1512,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||||
//
|
//
|
||||||
// UserAuth IsChild -> checks that account exists
|
// UserAuth IsChild -> checks that account exists
|
||||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
|
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
|
||||||
|
|
||||||
@@ -1590,7 +1591,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
|||||||
return domainAccountID, cancel, nil
|
return domainAccountID, cancel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) {
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||||
@@ -1638,7 +1639,7 @@ func handleNotFound(err error) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
|
func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool {
|
||||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package account
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
@@ -9,7 +10,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||||
@@ -45,10 +45,10 @@ type Manager interface {
|
|||||||
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
||||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||||
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||||
@@ -120,12 +120,12 @@ type Manager interface {
|
|||||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||||
GetStore() store.Store
|
GetStore() store.Store
|
||||||
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
SetEphemeralManager(em ephemeral.Manager)
|
SetEphemeralManager(em ephemeral.Manager)
|
||||||
AllowSync(string, uint64) bool
|
AllowSync(string, uint64) bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import (
|
|||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/cache"
|
"github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
@@ -42,6 +41,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) {
|
func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) {
|
||||||
@@ -442,7 +442,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||||
type initUserParams nbcontext.UserAuth
|
type initUserParams auth.UserAuth
|
||||||
|
|
||||||
var (
|
var (
|
||||||
publicDomain = "public.com"
|
publicDomain = "public.com"
|
||||||
@@ -465,7 +465,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
inputClaims nbcontext.UserAuth
|
inputClaims auth.UserAuth
|
||||||
inputInitUserParams initUserParams
|
inputInitUserParams initUserParams
|
||||||
inputUpdateAttrs bool
|
inputUpdateAttrs bool
|
||||||
inputUpdateClaimAccount bool
|
inputUpdateClaimAccount bool
|
||||||
@@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "New User With Public Domain",
|
name: "New User With Public Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: publicDomain,
|
Domain: publicDomain,
|
||||||
UserId: "pub-domain-user",
|
UserId: "pub-domain-user",
|
||||||
DomainCategory: types.PublicCategory,
|
DomainCategory: types.PublicCategory,
|
||||||
@@ -497,7 +497,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New User With Unknown Domain",
|
name: "New User With Unknown Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: unknownDomain,
|
Domain: unknownDomain,
|
||||||
UserId: "unknown-domain-user",
|
UserId: "unknown-domain-user",
|
||||||
DomainCategory: types.UnknownCategory,
|
DomainCategory: types.UnknownCategory,
|
||||||
@@ -514,7 +514,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New User With Private Domain",
|
name: "New User With Private Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: privateDomain,
|
Domain: privateDomain,
|
||||||
UserId: "pvt-domain-user",
|
UserId: "pvt-domain-user",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -531,7 +531,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Regular User With Existing Private Domain",
|
name: "New Regular User With Existing Private Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: privateDomain,
|
Domain: privateDomain,
|
||||||
UserId: "new-pvt-domain-user",
|
UserId: "new-pvt-domain-user",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -549,7 +549,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Existing User With Existing Reclassified Private Domain",
|
name: "Existing User With Existing Reclassified Private Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: defaultInitAccount.Domain,
|
Domain: defaultInitAccount.Domain,
|
||||||
UserId: defaultInitAccount.UserId,
|
UserId: defaultInitAccount.UserId,
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -566,7 +566,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: defaultInitAccount.Domain,
|
Domain: defaultInitAccount.Domain,
|
||||||
UserId: defaultInitAccount.UserId,
|
UserId: defaultInitAccount.UserId,
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -584,7 +584,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "User With Private Category And Empty Domain",
|
name: "User With Private Category And Empty Domain",
|
||||||
inputClaims: nbcontext.UserAuth{
|
inputClaims: auth.UserAuth{
|
||||||
Domain: "",
|
Domain: "",
|
||||||
UserId: "pvt-domain-user",
|
UserId: "pvt-domain-user",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -613,7 +613,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "get init account failed")
|
require.NoError(t, err, "get init account failed")
|
||||||
|
|
||||||
if testCase.inputUpdateAttrs {
|
if testCase.inputUpdateAttrs {
|
||||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||||
require.NoError(t, err, "update init user failed")
|
require.NoError(t, err, "update init user failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -653,7 +653,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
|||||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
require.NoError(t, err, "get init account failed")
|
require.NoError(t, err, "get init account failed")
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
@@ -912,13 +912,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
UserId: "pvt-domain-user",
|
UserId: "pvt-domain-user",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
}
|
}
|
||||||
|
|
||||||
publicClaims := nbcontext.UserAuth{
|
publicClaims := auth.UserAuth{
|
||||||
Domain: "test.com",
|
Domain: "test.com",
|
||||||
UserId: "public-domain-user",
|
UserId: "public-domain-user",
|
||||||
DomainCategory: types.PublicCategory,
|
DomainCategory: types.PublicCategory,
|
||||||
@@ -2648,7 +2648,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
||||||
|
|
||||||
t.Run("skip sync for token auth type", func(t *testing.T) {
|
t.Run("skip sync for token auth type", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{"group3"},
|
Groups: []string{"group3"},
|
||||||
@@ -2663,7 +2663,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty jwt groups", func(t *testing.T) {
|
t.Run("empty jwt groups", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{},
|
Groups: []string{},
|
||||||
@@ -2677,7 +2677,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("jwt match existing api group", func(t *testing.T) {
|
t.Run("jwt match existing api group", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{"group1"},
|
Groups: []string{"group1"},
|
||||||
@@ -2698,7 +2698,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
|
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
|
||||||
|
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{"group1"},
|
Groups: []string{"group1"},
|
||||||
@@ -2716,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add jwt group", func(t *testing.T) {
|
t.Run("add jwt group", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{"group1", "group2"},
|
Groups: []string{"group1", "group2"},
|
||||||
@@ -2730,7 +2730,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("existed group not update", func(t *testing.T) {
|
t.Run("existed group not update", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{"group2"},
|
Groups: []string{"group2"},
|
||||||
@@ -2744,7 +2744,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add new group", func(t *testing.T) {
|
t.Run("add new group", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user2",
|
UserId: "user2",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{"group1", "group3"},
|
Groups: []string{"group1", "group3"},
|
||||||
@@ -2762,7 +2762,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{},
|
Groups: []string{},
|
||||||
@@ -2777,7 +2777,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
|
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: "user2",
|
UserId: "user2",
|
||||||
AccountId: "accountID",
|
AccountId: "accountID",
|
||||||
Groups: []string{},
|
Groups: []string{},
|
||||||
@@ -3630,7 +3630,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding new user to existing account with approval required
|
// Test adding new user to existing account with approval required
|
||||||
newUserID := "new-user-id"
|
newUserID := "new-user-id"
|
||||||
userAuth := nbcontext.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
UserId: newUserID,
|
UserId: newUserID,
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -3660,7 +3660,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a domain-based account without user approval
|
// Create a domain-based account without user approval
|
||||||
ownerUserAuth := nbcontext.UserAuth{
|
ownerUserAuth := auth.UserAuth{
|
||||||
UserId: "owner-user",
|
UserId: "owner-user",
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
@@ -3679,7 +3679,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
|||||||
|
|
||||||
// Test adding new user to existing account without approval required
|
// Test adding new user to existing account without approval required
|
||||||
newUserID := "new-user-id"
|
newUserID := "new-user-id"
|
||||||
userAuth := nbcontext.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
UserId: newUserID,
|
UserId: newUserID,
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
DomainCategory: types.PrivateCategory,
|
DomainCategory: types.PrivateCategory,
|
||||||
|
|||||||
@@ -5,22 +5,22 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Manager = (*manager)(nil)
|
var _ Manager = (*manager)(nil)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
|
ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
|
||||||
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
|
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||||
}
|
}
|
||||||
@@ -55,20 +55,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
|
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
|
||||||
token, err := m.validator.ValidateAndParse(ctx, value)
|
token, err := m.validator.ValidateAndParse(ctx, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nbcontext.UserAuth{}, nil, err
|
return auth.UserAuth{}, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth, err := m.extractor.ToUserAuth(token)
|
userAuth, err := m.extractor.ToUserAuth(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nbcontext.UserAuth{}, nil, err
|
return auth.UserAuth{}, nil, err
|
||||||
}
|
}
|
||||||
return userAuth, token, err
|
return userAuth, token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
|
||||||
if userAuth.IsChild || userAuth.IsPAT {
|
if userAuth.IsChild || userAuth.IsPAT {
|
||||||
return userAuth, nil
|
return userAuth, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,18 +15,18 @@ var (
|
|||||||
|
|
||||||
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
|
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
|
ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error)
|
||||||
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
|
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error)
|
||||||
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
|
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
|
||||||
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureUserAccessByJWTGroups implements Manager.
|
// EnsureUserAccessByJWTGroups implements Manager.
|
||||||
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
|
||||||
if m.EnsureUserAccessByJWTGroupsFunc != nil {
|
if m.EnsureUserAccessByJWTGroupsFunc != nil {
|
||||||
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
|
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
|
||||||
}
|
}
|
||||||
return nbcontext.UserAuth{}, nil
|
return auth.UserAuth{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPATInfo implements Manager.
|
// GetPATInfo implements Manager.
|
||||||
@@ -46,9 +46,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAndParseToken implements Manager.
|
// ValidateAndParseToken implements Manager.
|
||||||
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
|
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) {
|
||||||
if m.ValidateAndParseTokenFunc != nil {
|
if m.ValidateAndParseTokenFunc != nil {
|
||||||
return m.ValidateAndParseTokenFunc(ctx, value)
|
return m.ValidateAndParseTokenFunc(ctx, value)
|
||||||
}
|
}
|
||||||
return nbcontext.UserAuth{}, &jwt.Token{}, nil
|
return auth.UserAuth{}, &jwt.Token{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbauth "github.com/netbirdio/netbird/shared/auth"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||||
@@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// this has been validated and parsed by ValidateAndParseToken
|
// this has been validated and parsed by ValidateAndParseToken
|
||||||
userAuth := nbcontext.UserAuth{
|
userAuth := nbauth.UserAuth{
|
||||||
AccountId: account.Id,
|
AccountId: account.Id,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
@@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tokenFunc func() string
|
tokenFunc func() string
|
||||||
expected *nbcontext.UserAuth // nil indicates expected error
|
expected *nbauth.UserAuth // nil indicates expected error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Valid with custom claims",
|
name: "Valid with custom claims",
|
||||||
@@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
|||||||
tokenString, _ := token.SignedString(key)
|
tokenString, _ := token.SignedString(key)
|
||||||
return tokenString
|
return tokenString
|
||||||
},
|
},
|
||||||
expected: &nbcontext.UserAuth{
|
expected: &nbauth.UserAuth{
|
||||||
UserId: "user-id|123",
|
UserId: "user-id|123",
|
||||||
AccountId: "account-id|567",
|
AccountId: "account-id|567",
|
||||||
Domain: "http://localhost",
|
Domain: "http://localhost",
|
||||||
@@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
|||||||
tokenString, _ := token.SignedString(key)
|
tokenString, _ := token.SignedString(key)
|
||||||
return tokenString
|
return tokenString
|
||||||
},
|
},
|
||||||
expected: &nbcontext.UserAuth{
|
expected: &nbauth.UserAuth{
|
||||||
UserId: "user-id|123",
|
UserId: "user-id|123",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type key int
|
type key int
|
||||||
@@ -13,45 +14,22 @@ const (
|
|||||||
UserAuthContextKey key = iota
|
UserAuthContextKey key = iota
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserAuth struct {
|
func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) {
|
||||||
// The account id the user is accessing
|
|
||||||
AccountId string
|
|
||||||
// The account domain
|
|
||||||
Domain string
|
|
||||||
// The account domain category, TBC values
|
|
||||||
DomainCategory string
|
|
||||||
// Indicates whether this user was invited, TBC logic
|
|
||||||
Invited bool
|
|
||||||
// Indicates whether this is a child account
|
|
||||||
IsChild bool
|
|
||||||
|
|
||||||
// The user id
|
|
||||||
UserId string
|
|
||||||
// Last login time for this user
|
|
||||||
LastLogin time.Time
|
|
||||||
// The Groups the user belongs to on this account
|
|
||||||
Groups []string
|
|
||||||
|
|
||||||
// Indicates whether this user has authenticated with a Personal Access Token
|
|
||||||
IsPAT bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
|
|
||||||
return GetUserAuthFromContext(r.Context())
|
return GetUserAuthFromContext(r.Context())
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
|
func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request {
|
||||||
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
|
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
|
func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) {
|
||||||
if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
|
if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok {
|
||||||
return userAuth, nil
|
return userAuth, nil
|
||||||
}
|
}
|
||||||
return UserAuth{}, fmt.Errorf("user auth not in context")
|
return auth.UserAuth{}, fmt.Errorf("user auth not in context")
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
|
func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context {
|
||||||
//nolint
|
//nolint
|
||||||
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
|
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
|
||||||
//nolint
|
//nolint
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
@@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: adminUser.Id,
|
UserId: adminUser.Id,
|
||||||
AccountId: accountID,
|
AccountId: accountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
@@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
|
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
|
||||||
AccountId: testingDNSSettingsAccount.Id,
|
AccountId: testingDNSSettingsAccount.Id,
|
||||||
Domain: testingDNSSettingsAccount.Domain,
|
Domain: testingDNSSettingsAccount.Domain,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
@@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
AccountId: testNSGroupAccountID,
|
AccountId: testNSGroupAccountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initEventsTestData(account string, events ...*activity.Event) *handler {
|
func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||||
@@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_account",
|
AccountId: "test_account",
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ import (
|
|||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// handler is a handler that returns groups of the account
|
// handler is a handler that returns groups of the account
|
||||||
|
|||||||
@@ -19,12 +19,13 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
var TestPeers = map[string]*nbpeer.Peer{
|
var TestPeers = map[string]*nbpeer.Peer{
|
||||||
@@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
|
|||||||
@@ -12,15 +12,15 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/management/server/networks"
|
"github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handler is a handler that returns networks of the account
|
// handler is a handler that returns networks of the account
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import (
|
|||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resourceHandler struct {
|
type resourceHandler struct {
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type routersHandler struct {
|
type routersHandler struct {
|
||||||
|
|||||||
@@ -17,9 +17,10 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -277,7 +278,7 @@ func TestGetPeers(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "admin_user",
|
UserId: "admin_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -425,7 +426,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: tc.callerUserID,
|
UserId: tc.callerUserID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -508,7 +509,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
|
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: tc.callerUserID,
|
UserId: tc.callerUserID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
|
|||||||
@@ -16,12 +16,13 @@ import (
|
|||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// handler is a handler that returns policy of the account
|
// handler is a handler that returns policy of the account
|
||||||
|
|||||||
@@ -14,10 +14,11 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initPoliciesTestData(policies ...*types.Policy) *handler {
|
func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||||
@@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,10 @@ import (
|
|||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
@@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "test_id",
|
AccountId: "test_id",
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
@@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: "test_user",
|
UserId: "test_user",
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: testAccountID,
|
AccountId: testAccountID,
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// handler is a handler that returns a list of setup keys of the account
|
// handler is a handler that returns a list of setup keys of the account
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: adminUser.Id,
|
UserId: adminUser.Id,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: "testAccountId",
|
AccountId: "testAccountId",
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// patHandler is the nameserver group handler of the account
|
// patHandler is the nameserver group handler of the account
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
@@ -128,7 +129,7 @@ func initUsersTestData() *handler {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||||
switch userAuth.UserId {
|
switch userAuth.UserId {
|
||||||
case "not-found":
|
case "not-found":
|
||||||
return nil, status.NewUserNotFoundError("not-found")
|
return nil, status.NewUserNotFoundError("not-found")
|
||||||
@@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
@@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
@@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
@@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
req = mux.SetURLVars(req, tc.requestVars)
|
req = mux.SetURLVars(req, tc.requestVars)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
@@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
req = mux.SetURLVars(req, tc.requestVars)
|
req = mux.SetURLVars(req, tc.requestVars)
|
||||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||||
UserId: existingUserID,
|
UserId: existingUserID,
|
||||||
Domain: testDomain,
|
Domain: testDomain,
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
@@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) {
|
|||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
requestAuth nbcontext.UserAuth
|
requestAuth auth.UserAuth
|
||||||
expectedResult *api.User
|
expectedResult *api.User
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "user not found",
|
name: "user not found",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "not-found"},
|
requestAuth: auth.UserAuth{UserId: "not-found"},
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "not of account",
|
name: "not of account",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "not-of-account"},
|
requestAuth: auth.UserAuth{UserId: "not-of-account"},
|
||||||
expectedStatus: http.StatusForbidden,
|
expectedStatus: http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "blocked user",
|
name: "blocked user",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "blocked-user"},
|
requestAuth: auth.UserAuth{UserId: "blocked-user"},
|
||||||
expectedStatus: http.StatusForbidden,
|
expectedStatus: http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "service user",
|
name: "service user",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "service-user"},
|
requestAuth: auth.UserAuth{UserId: "service-user"},
|
||||||
expectedStatus: http.StatusForbidden,
|
expectedStatus: http.StatusForbidden,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "owner",
|
name: "owner",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "owner"},
|
requestAuth: auth.UserAuth{UserId: "owner"},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedResult: &api.User{
|
expectedResult: &api.User{
|
||||||
Id: "owner",
|
Id: "owner",
|
||||||
@@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "regular user",
|
name: "regular user",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "regular-user"},
|
requestAuth: auth.UserAuth{UserId: "regular-user"},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedResult: &api.User{
|
expectedResult: &api.User{
|
||||||
Id: "regular-user",
|
Id: "regular-user",
|
||||||
@@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "admin user",
|
name: "admin user",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "admin-user"},
|
requestAuth: auth.UserAuth{UserId: "admin-user"},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedResult: &api.User{
|
expectedResult: &api.User{
|
||||||
Id: "admin-user",
|
Id: "admin-user",
|
||||||
@@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "restricted user",
|
name: "restricted user",
|
||||||
requestAuth: nbcontext.UserAuth{UserId: "restricted-user"},
|
requestAuth: auth.UserAuth{UserId: "restricted-user"},
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedResult: &api.User{
|
expectedResult: &api.User{
|
||||||
Id: "restricted-user",
|
Id: "restricted-user",
|
||||||
@@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
|||||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userAuth := nbcontext.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
UserId: tc.requestingUser.Id,
|
UserId: tc.requestingUser.Id,
|
||||||
}
|
}
|
||||||
@@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) {
|
|||||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
userAuth := nbcontext.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
AccountId: existingAccountID,
|
AccountId: existingAccountID,
|
||||||
UserId: tc.requestingUser.Id,
|
UserId: tc.requestingUser.Id,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,22 +10,23 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error
|
||||||
|
|
||||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
authManager auth.Manager
|
authManager serverauth.Manager
|
||||||
ensureAccount EnsureAccountFunc
|
ensureAccount EnsureAccountFunc
|
||||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||||
@@ -33,7 +34,7 @@ type AuthMiddleware struct {
|
|||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
func NewAuthMiddleware(
|
func NewAuthMiddleware(
|
||||||
authManager auth.Manager,
|
authManager serverauth.Manager,
|
||||||
ensureAccount EnsureAccountFunc,
|
ensureAccount EnsureAccountFunc,
|
||||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
@@ -53,18 +54,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
|
||||||
authType := strings.ToLower(auth[0])
|
authType := strings.ToLower(authHeader[0])
|
||||||
|
|
||||||
// fallback to token when receive pat as bearer
|
// fallback to token when receive pat as bearer
|
||||||
if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") {
|
if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") {
|
||||||
authType = "token"
|
authType = "token"
|
||||||
auth[0] = authType
|
authHeader[0] = authType
|
||||||
}
|
}
|
||||||
|
|
||||||
switch authType {
|
switch authType {
|
||||||
case "bearer":
|
case "bearer":
|
||||||
request, err := m.checkJWTFromRequest(r, auth)
|
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
@@ -73,7 +74,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
h.ServeHTTP(w, request)
|
h.ServeHTTP(w, request)
|
||||||
case "token":
|
case "token":
|
||||||
request, err := m.checkPATFromRequest(r, auth)
|
request, err := m.checkPATFromRequest(r, authHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
@@ -88,8 +89,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckJWTFromRequest checks if the JWT is valid
|
// CheckJWTFromRequest checks if the JWT is valid
|
||||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||||
token, err := getTokenFromJWTRequest(auth)
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
// If an error occurs, call the error handler and return an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -139,8 +140,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||||
token, err := getTokenFromPATRequest(auth)
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("error extracting token: %w", err)
|
return r, fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -159,7 +160,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
|||||||
return r, err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth := nbcontext.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
AccountId: user.AccountID,
|
AccountId: user.AccountID,
|
||||||
Domain: accDomain,
|
Domain: accDomain,
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
"github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
nbauth "github.com/netbirdio/netbird/shared/auth"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -61,9 +62,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
|||||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) {
|
||||||
if token == JWT {
|
if token == JWT {
|
||||||
return nbcontext.UserAuth{
|
return nbauth.UserAuth{
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
AccountId: accountID,
|
AccountId: accountID,
|
||||||
Domain: testAccount.Domain,
|
Domain: testAccount.Domain,
|
||||||
@@ -77,7 +78,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
|||||||
Valid: true,
|
Valid: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||||
@@ -87,7 +88,7 @@ func mockMarkPATUsed(_ context.Context, token string) error {
|
|||||||
return fmt.Errorf("Should never get reached")
|
return fmt.Errorf("Should never get reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) {
|
||||||
if userAuth.IsChild || userAuth.IsPAT {
|
if userAuth.IsChild || userAuth.IsPAT {
|
||||||
return userAuth, nil
|
return userAuth, nil
|
||||||
}
|
}
|
||||||
@@ -183,13 +184,13 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
|
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockAuth,
|
mockAuth,
|
||||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
return userAuth.AccountId, userAuth.UserId, nil
|
return userAuth.AccountId, userAuth.UserId, nil
|
||||||
},
|
},
|
||||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -226,13 +227,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
path string
|
path string
|
||||||
authHeader string
|
authHeader string
|
||||||
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
|
expectedUserAuth *nbauth.UserAuth // nil expects 401 response status
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Valid PAT Token",
|
name: "Valid PAT Token",
|
||||||
path: "/test",
|
path: "/test",
|
||||||
authHeader: "Token " + PAT,
|
authHeader: "Token " + PAT,
|
||||||
expectedUserAuth: &nbcontext.UserAuth{
|
expectedUserAuth: &nbauth.UserAuth{
|
||||||
AccountId: accountID,
|
AccountId: accountID,
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
Domain: testAccount.Domain,
|
Domain: testAccount.Domain,
|
||||||
@@ -244,7 +245,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
name: "Valid PAT Token accesses child",
|
name: "Valid PAT Token accesses child",
|
||||||
path: "/test?account=xyz",
|
path: "/test?account=xyz",
|
||||||
authHeader: "Token " + PAT,
|
authHeader: "Token " + PAT,
|
||||||
expectedUserAuth: &nbcontext.UserAuth{
|
expectedUserAuth: &nbauth.UserAuth{
|
||||||
AccountId: "xyz",
|
AccountId: "xyz",
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
Domain: testAccount.Domain,
|
Domain: testAccount.Domain,
|
||||||
@@ -257,7 +258,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
name: "Valid JWT Token",
|
name: "Valid JWT Token",
|
||||||
path: "/test",
|
path: "/test",
|
||||||
authHeader: "Bearer " + JWT,
|
authHeader: "Bearer " + JWT,
|
||||||
expectedUserAuth: &nbcontext.UserAuth{
|
expectedUserAuth: &nbauth.UserAuth{
|
||||||
AccountId: accountID,
|
AccountId: accountID,
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
Domain: testAccount.Domain,
|
Domain: testAccount.Domain,
|
||||||
@@ -269,7 +270,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
name: "Valid JWT Token with child",
|
name: "Valid JWT Token with child",
|
||||||
path: "/test?account=xyz",
|
path: "/test?account=xyz",
|
||||||
authHeader: "Bearer " + JWT,
|
authHeader: "Bearer " + JWT,
|
||||||
expectedUserAuth: &nbcontext.UserAuth{
|
expectedUserAuth: &nbauth.UserAuth{
|
||||||
AccountId: "xyz",
|
AccountId: "xyz",
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
Domain: testAccount.Domain,
|
Domain: testAccount.Domain,
|
||||||
@@ -288,13 +289,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
|
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockAuth,
|
mockAuth,
|
||||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
return userAuth.AccountId, userAuth.UserId, nil
|
return userAuth.AccountId, userAuth.UserId, nil
|
||||||
},
|
},
|
||||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||||
@@ -28,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||||
@@ -68,8 +68,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
}
|
}
|
||||||
|
|
||||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||||
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||||
authManagerMock := &auth.MockManager{
|
authManagerMock := &serverauth.MockManager{
|
||||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||||
MarkPATUsedFunc: authManager.MarkPATUsed,
|
MarkPATUsedFunc: authManager.MarkPATUsed,
|
||||||
@@ -114,8 +114,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.Up
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) {
|
||||||
userAuth := nbcontext.UserAuth{}
|
userAuth := auth.UserAuth{}
|
||||||
|
|
||||||
switch token {
|
switch token {
|
||||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||||
|
|||||||
@@ -1,138 +1,137 @@
|
|||||||
package idp
|
package idp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
func TestNewPocketIdManager(t *testing.T) {
|
func TestNewPocketIdManager(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
inputConfig PocketIdClientConfig
|
inputConfig PocketIdClientConfig
|
||||||
assertErrFunc require.ErrorAssertionFunc
|
assertErrFunc require.ErrorAssertionFunc
|
||||||
assertErrFuncMessage string
|
assertErrFuncMessage string
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultTestConfig := PocketIdClientConfig{
|
defaultTestConfig := PocketIdClientConfig{
|
||||||
APIToken: "api_token",
|
APIToken: "api_token",
|
||||||
ManagementEndpoint: "http://localhost",
|
ManagementEndpoint: "http://localhost",
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []test{
|
tests := []test{
|
||||||
{
|
{
|
||||||
name: "Good Configuration",
|
name: "Good Configuration",
|
||||||
inputConfig: defaultTestConfig,
|
inputConfig: defaultTestConfig,
|
||||||
assertErrFunc: require.NoError,
|
assertErrFunc: require.NoError,
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
assertErrFuncMessage: "shouldn't return error",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Missing ManagementEndpoint",
|
name: "Missing ManagementEndpoint",
|
||||||
inputConfig: PocketIdClientConfig{
|
inputConfig: PocketIdClientConfig{
|
||||||
APIToken: defaultTestConfig.APIToken,
|
APIToken: defaultTestConfig.APIToken,
|
||||||
ManagementEndpoint: "",
|
ManagementEndpoint: "",
|
||||||
},
|
},
|
||||||
assertErrFunc: require.Error,
|
assertErrFunc: require.Error,
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Missing APIToken",
|
name: "Missing APIToken",
|
||||||
inputConfig: PocketIdClientConfig{
|
inputConfig: PocketIdClientConfig{
|
||||||
APIToken: "",
|
APIToken: "",
|
||||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||||
},
|
},
|
||||||
assertErrFunc: require.Error,
|
assertErrFunc: require.Error,
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
assertErrFuncMessage: "should return error when field empty",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
func TestPocketID_GetUserDataByID(t *testing.T) {
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mgr.httpClient = client
|
mgr.httpClient = client
|
||||||
|
|
||||||
md := AppMetadata{WTAccountID: "acc1"}
|
md := AppMetadata{WTAccountID: "acc1"}
|
||||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "u1", got.ID)
|
assert.Equal(t, "u1", got.ID)
|
||||||
assert.Equal(t, "user1@example.com", got.Email)
|
assert.Equal(t, "user1@example.com", got.Email)
|
||||||
assert.Equal(t, "User One", got.Name)
|
assert.Equal(t, "User One", got.Name)
|
||||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
||||||
// Single page response with two users
|
// Single page response with two users
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mgr.httpClient = client
|
mgr.httpClient = client
|
||||||
|
|
||||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, users, 2)
|
require.Len(t, users, 2)
|
||||||
assert.Equal(t, "u1", users[0].ID)
|
assert.Equal(t, "u1", users[0].ID)
|
||||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||||
assert.Equal(t, "u2", users[1].ID)
|
assert.Equal(t, "u2", users[1].ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mgr.httpClient = client
|
mgr.httpClient = client
|
||||||
|
|
||||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, accounts[UnsetAccountID], 2)
|
require.Len(t, accounts[UnsetAccountID], 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPocketID_CreateUser(t *testing.T) {
|
func TestPocketID_CreateUser(t *testing.T) {
|
||||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mgr.httpClient = client
|
mgr.httpClient = client
|
||||||
|
|
||||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "newid", ud.ID)
|
assert.Equal(t, "newid", ud.ID)
|
||||||
assert.Equal(t, "new@example.com", ud.Email)
|
assert.Equal(t, "new@example.com", ud.Email)
|
||||||
assert.Equal(t, "New User", ud.Name)
|
assert.Equal(t, "New User", ud.Name)
|
||||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||||
}
|
}
|
||||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
||||||
// Same mock for both calls; returns OK with empty JSON
|
// Same mock for both calls; returns OK with empty JSON
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
mgr.httpClient = client
|
mgr.httpClient = client
|
||||||
|
|
||||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = mgr.DeleteUser(context.Background(), "u1")
|
err = mgr.DeleteUser(context.Background(), "u1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package mock_server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
@@ -12,7 +13,6 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||||
@@ -34,7 +34,7 @@ type MockAccountManager struct {
|
|||||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||||
@@ -84,7 +84,7 @@ type MockAccountManager struct {
|
|||||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||||
GetDNSDomainFunc func(settings *types.Settings) string
|
GetDNSDomainFunc func(settings *types.Settings) string
|
||||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||||
@@ -119,7 +119,7 @@ type MockAccountManager struct {
|
|||||||
GetStoreFunc func() store.Store
|
GetStoreFunc func() store.Store
|
||||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
|
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
|
||||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||||
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
|
||||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||||
@@ -469,7 +469,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
|
||||||
if am.GetUserFromUserAuthFunc != nil {
|
if am.GetUserFromUserAuthFunc != nil {
|
||||||
return am.GetUserFromUserAuthFunc(ctx, userAuth)
|
return am.GetUserFromUserAuthFunc(ctx, userAuth)
|
||||||
}
|
}
|
||||||
@@ -674,7 +674,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||||
if am.GetAccountIDFromUserAuthFunc != nil {
|
if am.GetAccountIDFromUserAuthFunc != nil {
|
||||||
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
||||||
}
|
}
|
||||||
@@ -936,7 +936,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
|
||||||
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -968,7 +968,7 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||||
if am.GetCurrentUserInfoFunc != nil {
|
if am.GetCurrentUserInfoFunc != nil {
|
||||||
return am.GetCurrentUserInfoFunc(ctx, userAuth)
|
return am.GetCurrentUserInfoFunc(ctx, userAuth)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
|
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type NetworkRouter struct {
|
type NetworkRouter struct {
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RouteFirewallRule a firewall rule applicable for a routed network.
|
// RouteFirewallRule a firewall rule applicable for a routed network.
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const channelBufferSize = 100
|
const channelBufferSize = 100
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,8 +13,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
@@ -175,9 +175,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
|
|||||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUser looks up a user by provided nbContext.UserAuths.
|
// GetUser looks up a user by provided auth.UserAuths.
|
||||||
// Expects account to have been created already.
|
// Expects account to have been created already.
|
||||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -940,7 +940,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
|||||||
var peerIDs []string
|
var peerIDs []string
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
|
ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key)
|
||||||
|
|
||||||
if peer.UserID == "" {
|
if peer.UserID == "" {
|
||||||
// we do not want to expire peers that are added via setup key
|
// we do not want to expire peers that are added via setup key
|
||||||
@@ -1171,7 +1171,7 @@ func validateUserInvite(invite *types.UserInfo) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetCurrentUserInfo retrieves the account's current user info and permissions
|
// GetCurrentUserInfo retrieves the account's current user info and permissions
|
||||||
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
|
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) {
|
||||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@@ -966,7 +966,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
|||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := nbcontext.UserAuth{
|
claims := auth.UserAuth{
|
||||||
UserId: mockUserID,
|
UserId: mockUserID,
|
||||||
AccountId: mockAccountID,
|
AccountId: mockAccountID,
|
||||||
}
|
}
|
||||||
@@ -1573,33 +1573,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
userAuth nbcontext.UserAuth
|
userAuth auth.UserAuth
|
||||||
expectedErr error
|
expectedErr error
|
||||||
expectedResult *users.UserInfoWithPermissions
|
expectedResult *users.UserInfoWithPermissions
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "not found",
|
name: "not found",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"},
|
||||||
expectedErr: status.NewUserNotFoundError("not-found"),
|
expectedErr: status.NewUserNotFoundError("not-found"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "not part of account",
|
name: "not part of account",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"},
|
||||||
expectedErr: status.NewUserNotPartOfAccountError(),
|
expectedErr: status.NewUserNotPartOfAccountError(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "blocked",
|
name: "blocked",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"},
|
||||||
expectedErr: status.NewUserBlockedError(),
|
expectedErr: status.NewUserBlockedError(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "service user",
|
name: "service user",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"},
|
||||||
expectedErr: status.NewPermissionDeniedError(),
|
expectedErr: status.NewPermissionDeniedError(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "owner user",
|
name: "owner user",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"},
|
||||||
expectedResult: &users.UserInfoWithPermissions{
|
expectedResult: &users.UserInfoWithPermissions{
|
||||||
UserInfo: &types.UserInfo{
|
UserInfo: &types.UserInfo{
|
||||||
ID: "account1Owner",
|
ID: "account1Owner",
|
||||||
@@ -1619,7 +1619,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "regular user",
|
name: "regular user",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"},
|
||||||
expectedResult: &users.UserInfoWithPermissions{
|
expectedResult: &users.UserInfoWithPermissions{
|
||||||
UserInfo: &types.UserInfo{
|
UserInfo: &types.UserInfo{
|
||||||
ID: "regular-user",
|
ID: "regular-user",
|
||||||
@@ -1638,7 +1638,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "admin user",
|
name: "admin user",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
|
userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"},
|
||||||
expectedResult: &users.UserInfoWithPermissions{
|
expectedResult: &users.UserInfoWithPermissions{
|
||||||
UserInfo: &types.UserInfo{
|
UserInfo: &types.UserInfo{
|
||||||
ID: "admin-user",
|
ID: "admin-user",
|
||||||
@@ -1657,7 +1657,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "settings blocked regular user",
|
name: "settings blocked regular user",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
|
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"},
|
||||||
expectedResult: &users.UserInfoWithPermissions{
|
expectedResult: &users.UserInfoWithPermissions{
|
||||||
UserInfo: &types.UserInfo{
|
UserInfo: &types.UserInfo{
|
||||||
ID: "settings-blocked-user",
|
ID: "settings-blocked-user",
|
||||||
@@ -1678,7 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
|
|
||||||
{
|
{
|
||||||
name: "settings blocked regular user child account",
|
name: "settings blocked regular user child account",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
|
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true},
|
||||||
expectedResult: &users.UserInfoWithPermissions{
|
expectedResult: &users.UserInfoWithPermissions{
|
||||||
UserInfo: &types.UserInfo{
|
UserInfo: &types.UserInfo{
|
||||||
ID: "settings-blocked-user",
|
ID: "settings-blocked-user",
|
||||||
@@ -1698,7 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "settings blocked owner user",
|
name: "settings blocked owner user",
|
||||||
userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
|
userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"},
|
||||||
expectedResult: &users.UserInfoWithPermissions{
|
expectedResult: &users.UserInfoWithPermissions{
|
||||||
UserInfo: &types.UserInfo{
|
UserInfo: &types.UserInfo{
|
||||||
ID: "account2Owner",
|
ID: "account2Owner",
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string {
|
|||||||
return url
|
return url
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
|
// ToUserAuth extracts user authentication information from a JWT token
|
||||||
|
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
userAuth := nbcontext.UserAuth{}
|
userAuth := auth.UserAuth{}
|
||||||
|
|
||||||
userID, ok := claims[c.userIDClaim].(string)
|
userID, ok := claims[c.userIDClaim].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro
|
|||||||
return userAuth, nil
|
return userAuth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToGroups extracts group information from a JWT token
|
||||||
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
|
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
userJWTGroups := make([]string, 0)
|
userJWTGroups := make([]string, 0)
|
||||||
28
shared/auth/user.go
Normal file
28
shared/auth/user.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserAuth struct {
|
||||||
|
// The account id the user is accessing
|
||||||
|
AccountId string
|
||||||
|
// The account domain
|
||||||
|
Domain string
|
||||||
|
// The account domain category, TBC values
|
||||||
|
DomainCategory string
|
||||||
|
// Indicates whether this user was invited, TBC logic
|
||||||
|
Invited bool
|
||||||
|
// Indicates whether this is a child account
|
||||||
|
IsChild bool
|
||||||
|
|
||||||
|
// The user id
|
||||||
|
UserId string
|
||||||
|
// Last login time for this user
|
||||||
|
LastLogin time.Time
|
||||||
|
// The Groups the user belongs to on this account
|
||||||
|
Groups []string
|
||||||
|
|
||||||
|
// Indicates whether this user has authenticated with a Personal Access Token
|
||||||
|
IsPAT bool
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user