Compare commits

...

3 Commits

Author SHA1 Message Date
pascal-fischer
6fec0c682e Merging full service user feature into main (#819)
Merging full feature branch into main.
Adding full support for service users including backend objects, persistence, verification and api endpoints.
2023-04-22 12:57:51 +02:00
Chinmay Pai
c2e90a2a97 feat: add support for custom device hostname (#789)
Configure via --hostname (or -n) flag in the `up` and `login` commands
---------

Signed-off-by: Chinmay D. Pai <chinmay.pai@zerodha.com>
2023-04-20 16:00:22 +02:00
Maycon Santos
118880b6f7 Send a status notification on offline peers change (#821)
Sum offline peers too
2023-04-20 15:59:07 +02:00
22 changed files with 981 additions and 177 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
) )
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
@@ -32,6 +33,11 @@ var loginCmd = &cobra.Command{
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
if hostName != "" {
// nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)

View File

@@ -45,6 +45,7 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
customDNSAddress string customDNSAddress string
@@ -94,6 +95,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, "preshared-key", "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd) rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd) rootCmd.AddCommand(downCmd)

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -55,6 +56,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
if hostName != "" {
// nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
if foregroundMode { if foregroundMode {
return runInForegroundMode(ctx, cmd) return runInForegroundMode(ctx, cmd)
} }

View File

@@ -78,6 +78,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
defer d.mux.Unlock() defer d.mux.Unlock()
d.offlinePeers = make([]State, len(replacement)) d.offlinePeers = make([]State, len(replacement))
copy(d.offlinePeers, replacement) copy(d.offlinePeers, replacement)
d.notifyPeerListChanged()
} }
// AddPeer adds peer to Daemon status map // AddPeer adds peer to Daemon status map
@@ -308,7 +309,7 @@ func (d *Status) onConnectionChanged() {
} }
func (d *Status) notifyPeerListChanged() { func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(len(d.peers)) d.notifier.peerListChanged(len(d.peers) + len(d.offlinePeers))
} }
func (d *Status) notifyAddressChanged() { func (d *Status) notifyAddressChanged() {

View File

@@ -43,6 +43,15 @@ func extractUserAgent(ctx context.Context) string {
return "" return ""
} }
// extractDeviceName extracts device name from context or returns the default system name
func extractDeviceName(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(DeviceNameCtxKey).(string)
if !ok {
return defaultName
}
return v
}
// GetDesktopUIUserAgent returns the Desktop ui user agent // GetDesktopUIUserAgent returns the Desktop ui user agent
func GetDesktopUIUserAgent() string { func GetDesktopUIUserAgent() string {
return "netbird-desktop-ui/" + version.NetbirdVersion() return "netbird-desktop-ui/" + version.NetbirdVersion()

View File

@@ -24,21 +24,13 @@ func GetInfo(ctx context.Context) *Info {
} }
gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: kernel, Core: osVersion(), Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname = extractDeviceName(ctx) gio.Hostname = extractDeviceName(ctx, "android")
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)
return gio return gio
} }
func extractDeviceName(ctx context.Context) string {
v, ok := ctx.Value(DeviceNameCtxKey).(string)
if !ok {
return "android"
}
return v
}
func uname() []string { func uname() []string {
res := run("/system/bin/uname", "-a") res := run("/system/bin/uname", "-a")
return strings.Split(res, " ") return strings.Split(res, " ")

View File

@@ -32,7 +32,8 @@ func GetInfo(ctx context.Context) *Info {
swVersion = []byte(release) swVersion = []byte(release)
} }
gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: sysName, OSVersion: strings.TrimSpace(string(swVersion)), Core: release, Platform: machine, OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -24,7 +24,8 @@ func GetInfo(ctx context.Context) *Info {
osStr = strings.Replace(osStr, "\r\n", "", -1) osStr = strings.Replace(osStr, "\r\n", "", -1)
osInfo := strings.Split(osStr, " ") osInfo := strings.Split(osStr, " ")
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -50,7 +50,8 @@ func GetInfo(ctx context.Context) *Info {
osName = osInfo[3] osName = osInfo[3]
} }
gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: osInfo[0], Core: osInfo[1], Platform: osInfo[2], OS: osName, OSVersion: osVer, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -24,3 +24,12 @@ func Test_UIVersion(t *testing.T) {
got := GetInfo(ctx) got := GetInfo(ctx)
assert.Equal(t, want, got.UIVersion) assert.Equal(t, want, got.UIVersion)
} }
func Test_CustomHostname(t *testing.T) {
// nolint
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
want := "custom-host"
got := GetInfo(ctx)
assert.Equal(t, want, got.Hostname)
}

View File

@@ -16,7 +16,8 @@ import (
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
ver := getOSVersion() ver := getOSVersion()
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion() gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)

View File

@@ -49,15 +49,17 @@ type AccountManager interface {
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error) autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, userID string, key *UserInfo) (*UserInfo, error) CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(accountID, executingUserID string, targetUserID string) error
ListSetupKeys(accountID, userID string) ([]*SetupKey, error) ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, userID string, update *User) (*UserInfo, error) SaveUser(accountID, userID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountByUserID(userID string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(userID string) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error) GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error)
@@ -177,6 +179,7 @@ type UserInfo struct {
Role string `json:"role"` Role string `json:"role"`
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
Status string `json:"-"` Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
} }
// getRoutesToSync returns the enabled routes for the peer ID and the routes // getRoutesToSync returns the enabled routes for the peer ID and the routes
@@ -1228,10 +1231,12 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser {
err = am.redeemInvite(account, claims.UserId) err = am.redeemInvite(account, claims.UserId)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
}
return account, user, nil return account, user, nil
} }

View File

@@ -87,6 +87,10 @@ const (
PersonalAccessTokenCreated PersonalAccessTokenCreated
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token // PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted PersonalAccessTokenDeleted
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
) )
const ( const (
@@ -176,6 +180,10 @@ const (
PersonalAccessTokenCreatedMessage string = "Personal access token created" PersonalAccessTokenCreatedMessage string = "Personal access token created"
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity // PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
PersonalAccessTokenDeletedMessage string = "Personal access token deleted" PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
// ServiceUserCreatedMessage is a human-readable text message of the ServiceUserCreated activity
ServiceUserCreatedMessage string = "Service user created"
// ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity
ServiceUserDeletedMessage string = "Service user deleted"
) )
// Activity that triggered an Event // Activity that triggered an Event
@@ -270,6 +278,10 @@ func (a Activity) Message() string {
return PersonalAccessTokenCreatedMessage return PersonalAccessTokenCreatedMessage
case PersonalAccessTokenDeleted: case PersonalAccessTokenDeleted:
return PersonalAccessTokenDeletedMessage return PersonalAccessTokenDeletedMessage
case ServiceUserCreated:
return ServiceUserCreatedMessage
case ServiceUserDeleted:
return ServiceUserDeletedMessage
default: default:
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }
@@ -364,6 +376,10 @@ func (a Activity) StringCode() string {
return "personal.access.token.create" return "personal.access.token.create"
case PersonalAccessTokenDeleted: case PersonalAccessTokenDeleted:
return "personal.access.token.delete" return "personal.access.token.delete"
case ServiceUserCreated:
return "service.user.create"
case ServiceUserDeleted:
return "service.user.delete"
default: default:
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }

View File

@@ -77,6 +77,10 @@ components:
description: Is true if authenticated user is the same as this user description: Is true if authenticated user is the same as this user
type: boolean type: boolean
readOnly: true readOnly: true
is_service_user:
description: Is true if this user is a service user
type: boolean
readOnly: true
required: required:
- id - id
- email - email
@@ -115,10 +119,13 @@ components:
type: array type: array
items: items:
type: string type: string
is_service_user:
description: Is true if this user is a service user
type: boolean
required: required:
- role - role
- auto_groups - auto_groups
- email - is_service_user
PeerMinimum: PeerMinimum:
type: object type: object
properties: properties:
@@ -825,6 +832,12 @@ paths:
tags: [ Users ] tags: [ Users ]
security: security:
- BearerAuth: [ ] - BearerAuth: [ ]
parameters:
- in: query
name: service_user
schema:
type: boolean
description: Filters users and returns either normal users or service users
responses: responses:
'200': '200':
description: A JSON array of Users description: A JSON array of Users
@@ -903,6 +916,30 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
delete:
summary: Delete a User
tags: [ Users ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The User ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{userId}/tokens: /api/users/{userId}/tokens:
get: get:
summary: Returns a list of all tokens for a user summary: Returns a list of all tokens for a user

View File

@@ -677,6 +677,9 @@ type User struct {
// IsCurrent Is true if authenticated user is the same as this user // IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"` IsCurrent *bool `json:"is_current,omitempty"`
// IsServiceUser Is true if this user is a service user
IsServiceUser *bool `json:"is_service_user,omitempty"`
// Name User's name from idp provider // Name User's name from idp provider
Name string `json:"name"` Name string `json:"name"`
@@ -696,7 +699,10 @@ type UserCreateRequest struct {
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Email User's Email to send invite to // Email User's Email to send invite to
Email string `json:"email"` Email *string `json:"email,omitempty"`
// IsServiceUser Is true if this user is a service user
IsServiceUser bool `json:"is_service_user"`
// Name User's full name // Name User's full name
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
@@ -787,6 +793,12 @@ type PutApiRulesIdJSONBody struct {
Sources *[]string `json:"sources,omitempty"` Sources *[]string `json:"sources,omitempty"`
} }
// GetApiUsersParams defines parameters for GetApiUsers.
type GetApiUsersParams struct {
// ServiceUser Filters users and returns either normal users or service users
ServiceUser *bool `form:"service_user,omitempty" json:"service_user,omitempty"`
}
// PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType. // PutApiAccountsIdJSONRequestBody defines body for PutApiAccountsId for application/json ContentType.
type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody type PutApiAccountsIdJSONRequestBody PutApiAccountsIdJSONBody

View File

@@ -111,6 +111,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") apiHandler.Router.HandleFunc("/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{id}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
} }

View File

@@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) type IsUserAdminFunc func(userID string) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct { type AccessControl struct {
@@ -37,7 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := a.claimsExtract.FromRequestContext(r) claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(claims) ok, err := a.isUserAdmin(claims.UserId)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return

View File

@@ -3,8 +3,10 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
@@ -77,6 +79,36 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
} }
// DeleteUser is a DELETE request to delete a user (only works for service users right now)
func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["id"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
}
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). // CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@@ -103,11 +135,17 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
email := ""
if req.Email != nil {
email = *req.Email
}
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{ newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
Email: req.Email, Email: email,
Name: *req.Name, Name: *req.Name,
Role: req.Role, Role: req.Role,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
IsServiceUser: req.IsServiceUser,
}) })
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@@ -137,9 +175,27 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
serviceUser := r.URL.Query().Get("service_user")
log.Debugf("UserCount: %v", len(data))
users := make([]*api.User, 0) users := make([]*api.User, 0)
for _, r := range data { for _, r := range data {
if serviceUser == "" {
users = append(users, toUserResponse(r, claims.UserId)) users = append(users, toUserResponse(r, claims.UserId))
continue
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.Debugf("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return
}
log.Debugf("User %v is service user: %v", r.Name, r.IsServiceUser)
if includeServiceUser == r.IsServiceUser {
log.Debugf("Found service user: %v", r.Name)
users = append(users, toUserResponse(r, claims.UserId))
}
} }
util.WriteJSONObject(w, users) util.WriteJSONObject(w, users)
@@ -170,5 +226,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: userStatus, Status: userStatus,
IsCurrent: &isCurrent, IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
} }
} }

View File

@@ -1,52 +1,91 @@
package http package http
import ( import (
"bytes"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/magiconair/properties/assert" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
) )
func initUsers(user ...*server.User) *UsersHandler { const (
serviceUserID = "serviceUserID"
regularUserID = "regularUserID"
)
var usersTestAccount = &server.Account{
Id: existingAccountID,
Domain: domain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
Role: "admin",
IsServiceUser: false,
},
regularUserID: {
Id: regularUserID,
Role: "user",
IsServiceUser: false,
},
serviceUserID: {
Id: serviceUserID,
Role: "user",
IsServiceUser: true,
},
},
}
func initUsersTestData() *UsersHandler {
return &UsersHandler{ return &UsersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
users := make(map[string]*server.User, 0) return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
for _, u := range user {
users[u.Id] = u
}
return &server.Account{
Id: "12345",
Domain: "netbird.io",
Users: users,
}, users[claims.UserId], nil
}, },
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)
for _, v := range user { for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{ users = append(users, &server.UserInfo{
ID: v.Id, ID: v.Id,
Role: string(v.Role), Role: string(v.Role),
Name: "", Name: "",
Email: "", Email: "",
IsServiceUser: v.IsServiceUser,
}) })
} }
return users, nil return users, nil
}, },
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
return key, nil
},
DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
if !usersTestAccount.Users[targetUserID].IsServiceUser {
return status.Errorf(status.PermissionDenied, "user with ID %s is not a service user and can not be deleted", targetUserID)
}
return nil
},
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "1", UserId: existingUserID,
Domain: "hotmail.com", Domain: domain,
AccountId: "test_id", AccountId: existingAccountID,
} }
}), }),
), ),
@@ -54,8 +93,84 @@ func initUsers(user ...*server.User) *UsersHandler {
} }
func TestGetUsers(t *testing.T) { func TestGetUsers(t *testing.T) {
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}} tt := []struct {
userHandler := initUsers(users...) name string
expectedStatus int
requestType string
requestPath string
expectedUserIDs []string
}{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
userHandler.GetAllUsers(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
respBody := []*server.UserInfo{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, len(respBody), len(tc.expectedUserIDs))
for _, v := range respBody {
assert.Contains(t, tc.expectedUserIDs, v.ID)
assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id)
assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role))
assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser)
}
})
}
}
func TestCreateUser(t *testing.T) {
name := "name"
email := "email"
serviceUserToAdd := api.UserCreateRequest{
AutoGroups: []string{},
Email: nil,
IsServiceUser: true,
Name: &name,
Role: "admin",
}
serviceUserString, err := json.Marshal(serviceUserToAdd)
if err != nil {
t.Fatal(err)
}
regularUserToAdd := api.UserCreateRequest{
AutoGroups: []string{},
Email: &email,
IsServiceUser: true,
Name: &name,
Role: "admin",
}
regularUserString, err := json.Marshal(regularUserToAdd)
if err != nil {
t.Fatal(err)
}
tt := []struct { tt := []struct {
name string name string
@@ -65,40 +180,79 @@ func TestGetUsers(t *testing.T) {
requestBody io.Reader requestBody io.Reader
expectedResult []*server.User expectedResult []*server.User
}{ }{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users}, {name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
// right now creation is blocked in AC middleware, will be refactored in the future
{name: "CreateRegularUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(regularUserString)},
} }
userHandler := initUsersTestData()
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.GetAllUsers(rr, req) userHandler.CreateUser(rr, req)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus { if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK) status, tc.expectedStatus)
} }
})
content, err := io.ReadAll(res.Body) }
if err != nil { }
t.Fatal(err)
} func TestDeleteUser(t *testing.T) {
tt := []struct {
respBody := []*server.UserInfo{} name string
err = json.Unmarshal(content, &respBody) expectedStatus int
if err != nil { expectedBody bool
t.Fatalf("Sent content is not in correct json format; %v", err) requestType string
} requestPath string
requestVars map[string]string
if tc.expectedResult != nil { requestBody io.Reader
for i, resp := range respBody { }{
assert.Equal(t, resp.ID, tc.expectedResult[i].Id) {
assert.Equal(t, string(resp.Role), string(tc.expectedResult[i].Role)) name: "Delete Regular User",
} requestType: http.MethodDelete,
requestPath: "/api/users/" + regularUserID,
requestVars: map[string]string{"id": regularUserID},
expectedStatus: http.StatusForbidden,
},
{
name: "Delete Service User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + serviceUserID,
requestVars: map[string]string{"id": serviceUserID},
expectedStatus: http.StatusOK,
},
{
name: "Delete Not Existing User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + notFoundUserID,
requestVars: map[string]string{"id": notFoundUserID},
expectedStatus: http.StatusNotFound,
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder()
userHandler.DeleteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
} }
}) })
} }

View File

@@ -15,12 +15,12 @@ import (
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error) GetAccountByUserIDFunc func(userID string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdminFunc func(userID string) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error) AccountExistsFunc func(accountId string) (*bool, error)
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error) GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
@@ -61,6 +61,7 @@ type MockAccountManager struct {
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
@@ -112,12 +113,12 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
) )
} }
// GetAccountByUser mock implementation of GetAccountByUser from server.AccountManager interface // GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) { func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) {
if am.GetAccountByUserFunc != nil { if am.GetAccountByUserIDFunc != nil {
return am.GetAccountByUserFunc(userId) return am.GetAccountByUserIDFunc(userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented")
} }
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface // CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
@@ -394,9 +395,9 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst
} }
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface // IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) {
if am.IsUserAdminFunc != nil { if am.IsUserAdminFunc != nil {
return am.IsUserAdminFunc(claims) return am.IsUserAdminFunc(userID)
} }
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented") return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
} }
@@ -500,6 +501,14 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
} }
// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
return am.DeleteUserFunc(accountID, executingUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
}
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil { if am.GetNameServerGroupFunc != nil {

View File

@@ -4,11 +4,11 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -44,6 +44,9 @@ type UserRole string
type User struct { type User struct {
Id string Id string
Role UserRole Role UserRole
IsServiceUser bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string AutoGroups []string
PATs map[string]*PersonalAccessToken PATs map[string]*PersonalAccessToken
@@ -65,10 +68,11 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
return &UserInfo{ return &UserInfo{
ID: u.Id, ID: u.Id,
Email: "", Email: "",
Name: "", Name: u.ServiceUserName,
Role: string(u.Role), Role: string(u.Role),
AutoGroups: u.AutoGroups, AutoGroups: u.AutoGroups,
Status: string(UserStatusActive), Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
}, nil }, nil
} }
if userData.ID != u.Id { if userData.ID != u.Id {
@@ -87,6 +91,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
Role: string(u.Role), Role: string(u.Role),
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: string(userStatus), Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
}, nil }, nil
} }
@@ -104,31 +109,85 @@ func (u *User) Copy() *User {
Id: u.Id, Id: u.Id,
Role: u.Role, Role: u.Role,
AutoGroups: autoGroups, AutoGroups: autoGroups,
IsServiceUser: u.IsServiceUser,
ServiceUserName: u.ServiceUserName,
PATs: pats, PATs: pats,
} }
} }
// NewUser creates a new user // NewUser creates a new user
func NewUser(id string, role UserRole) *User { func NewUser(id string, role UserRole, isServiceUser bool, serviceUserName string, autoGroups []string) *User {
return &User{ return &User{
Id: id, Id: id,
Role: role, Role: role,
AutoGroups: []string{}, IsServiceUser: isServiceUser,
ServiceUserName: serviceUserName,
AutoGroups: autoGroups,
} }
} }
// NewRegularUser creates a new user with role UserRoleAdmin // NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User { func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser) return NewUser(id, UserRoleUser, false, "", []string{})
} }
// NewAdminUser creates a new user with role UserRoleAdmin // NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User { func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin) return NewUser(id, UserRoleAdmin, false, "", []string{})
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
executingUser := account.Users[executingUserID]
if executingUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin {
return nil, status.Errorf(status.PermissionDenied, "only admins can create service users")
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, serviceUserName, autoGroups)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
meta := map[string]any{"name": newUser.ServiceUserName}
am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
return &UserInfo{
ID: newUser.Id,
Email: "",
Name: newUser.ServiceUserName,
Role: string(newUser.Role),
AutoGroups: newUser.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: true,
}, nil
} }
// CreateUser creates a new user under the given account. Effectively this is a user invite. // CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) { func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *UserInfo) (*UserInfo, error) {
if user.IsServiceUser {
return am.createServiceUser(accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.AutoGroups)
}
return am.inviteNewUser(accountID, userID, user)
}
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -193,8 +252,48 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us
} }
// DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
targetUser := account.Users[targetUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[executingUserID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin {
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
}
if !targetUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
}
meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUserID)
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// CreatePAT creates a new PAT for the given user // CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -206,21 +305,26 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
} }
if executingUserID != targetUserId {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser := account.Users[targetUserId] targetUser := account.Users[targetUserID]
if targetUser == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "targetUser not found") return nil, status.Errorf(status.NotFound, "targetUser not found")
} }
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id) executingUser := account.Users[executingUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
if err != nil { if err != nil {
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
} }
@@ -232,8 +336,8 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str
return nil, status.Errorf(status.Internal, "failed to save account: %v", err) return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
} }
meta := map[string]any{"name": pat.Name} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta) am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta)
return pat, nil return pat, nil
} }
@@ -243,21 +347,26 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
if executingUserID != targetUserID {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(status.NotFound, "account not found: %s", err) return status.Errorf(status.NotFound, "account not found: %s", err)
} }
user := account.Users[targetUserID] targetUser := account.Users[targetUserID]
if user == nil { if targetUser == nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
pat := user.PATs[tokenID] executingUser := account.Users[executingUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
pat := targetUser.PATs[tokenID]
if pat == nil { if pat == nil {
return status.Errorf(status.NotFound, "PAT not found") return status.Errorf(status.NotFound, "PAT not found")
} }
@@ -271,10 +380,10 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
} }
meta := map[string]any{"name": pat.Name} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(user.PATs, tokenID) delete(targetUser.PATs, tokenID)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
@@ -288,21 +397,26 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
if executingUserID != targetUserID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found: %s", err) return nil, status.Errorf(status.NotFound, "account not found: %s", err)
} }
user := account.Users[targetUserID] targetUser := account.Users[targetUserID]
if user == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
pat := user.PATs[tokenID] executingUser := account.Users[executingUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
}
pat := targetUser.PATs[tokenID]
if pat == nil { if pat == nil {
return nil, status.Errorf(status.NotFound, "PAT not found") return nil, status.Errorf(status.NotFound, "PAT not found")
} }
@@ -315,22 +429,27 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
if executingUserID != targetUserID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(status.NotFound, "account not found: %s", err) return nil, status.Errorf(status.NotFound, "account not found: %s", err)
} }
user := account.Users[targetUserID] targetUser := account.Users[targetUserID]
if user == nil { if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, status.Errorf(status.NotFound, "user not found")
} }
executingUser := account.Users[executingUserID]
if targetUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
var pats []*PersonalAccessToken var pats []*PersonalAccessToken
for _, pat := range user.PATs { for _, pat := range targetUser.PATs {
pats = append(pats, pat) pats = append(pats, pat)
} }
@@ -386,7 +505,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
group := account.GetGroup(g) group := account.GetGroup(g)
if group != nil { if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser, am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID}) map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else { } else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
@@ -397,14 +516,14 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User
group := account.GetGroup(g) group := account.GetGroup(g)
if group != nil { if group != nil {
am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser, am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID}) map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else { } else {
log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
} }
}() }()
if !isNil(am.idpManager) { if !isNil(am.idpManager) && !newUser.IsServiceUser {
userData, err := am.lookupUserInCache(newUser.Id, account) userData, err := am.lookupUserInCache(newUser.Id, account)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -454,14 +573,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
return account, nil return account, nil
} }
// IsUserAdmin flag for current user authenticated by JWT token // GetAccountByUserID returns an existing account for a given user id
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) {
account, _, err := am.GetAccountFromToken(claims) return am.Store.GetAccountByUser(userID)
}
// IsUserAdmin looks up a user by his ID and returns true if he is an admin
func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) {
account, err := am.GetAccountByUserID(userID)
if err != nil { if err != nil {
return false, fmt.Errorf("get account: %v", err) return false, fmt.Errorf("get account: %v", err)
} }
user, ok := account.Users[claims.UserId] user, ok := account.Users[userID]
if !ok { if !ok {
return false, status.Errorf(status.NotFound, "user not found") return false, status.Errorf(status.NotFound, "user not found")
} }
@@ -486,8 +610,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users { for _, user := range account.Users {
if !user.IsServiceUser {
users[user.Id] = struct{}{} users[user.Id] = struct{}{}
} }
}
queriedUsers, err = am.lookupCache(users, accountID) queriedUsers, err = am.lookupCache(users, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -512,20 +638,44 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
return userInfos, nil return userInfos, nil
} }
for _, queriedUser := range queriedUsers { for _, localUser := range account.Users {
if !user.IsAdmin() && user.Id != queriedUser.ID { if !user.IsAdmin() && user.Id != localUser.Id {
// if user is not an admin then show only current user and do not show other users // if user is not an admin then show only current user and do not show other users
continue continue
} }
if localUser, contains := account.Users[queriedUser.ID]; contains {
info, err := localUser.toUserInfo(queriedUser) var info *UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.toUserInfo(queriedUser)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userInfos = append(userInfos, info) } else {
name := ""
if localUser.IsServiceUser {
name = localUser.ServiceUserName
} }
info = &UserInfo{
ID: localUser.Id,
Email: "",
Name: name,
Role: string(localUser.Role),
AutoGroups: localUser.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: localUser.IsServiceUser,
}
}
userInfos = append(userInfos, info)
} }
return userInfos, nil return userInfos, nil
} }
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {
return user, true
}
}
return nil, false
}

View File

@@ -1,8 +1,12 @@
package server package server
import ( import (
"fmt"
"reflect"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@@ -11,6 +15,9 @@ import (
const ( const (
mockAccountID = "accountID" mockAccountID = "accountID"
mockUserID = "userID" mockUserID = "userID"
mockServiceUserID = "serviceUserID"
mockRole = "user"
mockServiceUserName = "serviceUserName"
mockTargetUserId = "targetUserID" mockTargetUserId = "targetUserID"
mockTokenID1 = "tokenID1" mockTokenID1 = "tokenID1"
mockToken1 = "SoMeHaShEdToKeN1" mockToken1 = "SoMeHaShEdToKeN1"
@@ -41,6 +48,8 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
assert.Equal(t, pat.CreatedBy, mockUserID)
fileStore := am.Store.(*FileStore) fileStore := am.Store.(*FileStore)
tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken]
@@ -60,7 +69,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
Id: mockTargetUserId,
IsServiceUser: false,
}
err := store.SaveAccount(account) err := store.SaveAccount(account)
if err != nil { if err != nil {
t.Fatalf("Error when saving account: %s", err) t.Fatalf("Error when saving account: %s", err)
@@ -75,6 +87,31 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
assert.Errorf(t, err, "Creating PAT for different user should thorw error") assert.Errorf(t, err, "Creating PAT for different user should thorw error")
} }
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
Id: mockTargetUserId,
IsServiceUser: true,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
pat, err := am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err)
}
assert.Equal(t, pat.CreatedBy, mockUserID)
}
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
@@ -207,3 +244,300 @@ func TestUser_GetAllPATs(t *testing.T) {
assert.Equal(t, 2, len(pats)) assert.Equal(t, 2, len(pats))
} }
func TestUser_Copy(t *testing.T) {
// this is an imaginary case which will never be in DB this way
user := User{
Id: "userId",
Role: "role",
IsServiceUser: true,
ServiceUserName: "servicename",
AutoGroups: []string{"group1", "group2"},
PATs: map[string]*PersonalAccessToken{
"pat1": {
ID: "pat1",
Name: "First PAT",
HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: time.Now().AddDate(0, 0, 7),
CreatedBy: "userId",
CreatedAt: time.Now(),
LastUsed: time.Now(),
},
},
}
err := validateStruct(user)
if err != nil {
t.Fatalf("Test needs update: dummy struct has not all fields set : %s", err)
}
copiedUser := user.Copy()
assert.True(t, cmp.Equal(user, *copiedUser))
}
// based on https://medium.com/@anajankow/fast-check-if-all-struct-fields-are-set-in-golang-bba1917213d2
func validateStruct(s interface{}) (err error) {
structType := reflect.TypeOf(s)
structVal := reflect.ValueOf(s)
fieldNum := structVal.NumField()
for i := 0; i < fieldNum; i++ {
field := structVal.Field(i)
fieldName := structType.Field(i).Name
isSet := field.IsValid() && !field.IsZero()
if !isSet {
err = fmt.Errorf("%v%s in not set; ", err, fieldName)
}
}
return err
}
func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
user, err := am.createServiceUser(mockAccountID, mockUserID, mockRole, mockServiceUserName, []string{"group1", "group2"})
if err != nil {
t.Fatalf("Error when creating service user: %s", err)
}
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID])
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs)
assert.Zero(t, user.Email)
assert.True(t, user.IsServiceUser)
assert.Equal(t, "active", user.Status)
}
func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
user, err := am.CreateUser(mockAccountID, mockUserID, &UserInfo{
Name: mockServiceUserName,
Role: mockRole,
IsServiceUser: true,
AutoGroups: []string{"group1", "group2"},
})
if err != nil {
t.Fatalf("Error when creating user: %s", err)
}
assert.True(t, user.IsServiceUser)
assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
assert.Equal(t, mockServiceUserName, user.Name)
assert.Equal(t, mockRole, user.Role)
assert.Equal(t, []string{"group1", "group2"}, user.AutoGroups)
assert.Equal(t, "active", user.Status)
}
func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
_, err = am.CreateUser(mockAccountID, mockUserID, &UserInfo{
Name: mockServiceUserName,
Role: mockRole,
IsServiceUser: false,
AutoGroups: []string{"group1", "group2"},
})
assert.Errorf(t, err, "Not configured IDP will throw error but right path used")
}
func TestUser_DeleteUser_ServiceUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
IsServiceUser: true,
ServiceUserName: mockServiceUserName,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
err = am.DeleteUser(mockAccountID, mockUserID, mockServiceUserID)
if err != nil {
t.Fatalf("Error when deleting user: %s", err)
}
assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users))
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
}
func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
}
func TestUser_IsUserAdmin_ForAdmin(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
ok, err := am.IsUserAdmin(mockUserID)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.True(t, ok)
}
func TestUser_IsUserAdmin_ForUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
Role: "user",
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
ok, err := am.IsUserAdmin(mockUserID)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}
assert.False(t, ok)
}
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
users, err := am.GetUsersFromAccount(mockAccountID, mockUserID)
if err != nil {
t.Fatalf("Error when getting users from account: %s", err)
}
assert.Equal(t, 2, len(users))
}
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
users, err := am.GetUsersFromAccount(mockAccountID, mockServiceUserID)
if err != nil {
t.Fatalf("Error when getting users from account: %s", err)
}
assert.Equal(t, 1, len(users))
assert.Equal(t, mockServiceUserID, users[0].ID)
}