diff --git a/go.mod b/go.mod index 198253ed2..fc63e1b95 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/rs/xid v1.3.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/stretchr/testify v1.7.0 + github.com/wiretrustee/wiretrustee v0.5.1 ) require ( diff --git a/go.sum b/go.sum index d2c5b6903..2cb950995 100644 --- a/go.sum +++ b/go.sum @@ -497,6 +497,8 @@ github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7Zo github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb h1:CU1/+CEeCPvYXgfAyqTJXSQSf6hW3wsWM6Dfz6HkHEQ= github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb/go.mod h1:XT1Nrb4OxbVFPffbQMbq4PaeEkpRLVzdphh3fjrw7DY= +github.com/wiretrustee/wiretrustee v0.5.1 h1:mCuGOjsys5KgiLcRA38WYSwtyQ6YBPOZYquFz6EvTvk= +github.com/wiretrustee/wiretrustee v0.5.1/go.mod h1:y6sXInx4Fur1pTfiZUNI+S73AEj1nvIiM1OJzCCK7z0= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/account.go b/management/server/account.go index 3c58a7340..d1b8ff587 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1,6 +1,7 @@ package server import ( + "reflect" "strings" "sync" @@ -42,6 +43,7 @@ type AccountManager interface { GetPeerByIP(accountId string, peerIP string) (*Peer, error) GetNetworkMap(peerKey string) (*NetworkMap, error) AddPeer(setupKey string, peer *Peer) (*Peer, error) + GetUsersFromAccount(accountId string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) SaveGroup(accountId string, group *Group) error DeleteGroup(accountId, groupID string) error @@ -74,6 +76,13 @@ type Account struct { Groups map[string]*Group } +type UserInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` +} + // NewAccount creates a new Account with a generated ID and generated default setup keys func NewAccount(userId, domain string) *Account { accountId := xid.New().String() @@ -107,11 +116,7 @@ func (a *Account) Copy() *Account { } // NewManager creates a new DefaultAccountManager with a provided Store -func NewManager( - store Store, - peersUpdateManager *PeersUpdateManager, - idpManager idp.Manager, -) *DefaultAccountManager { +func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager { return &DefaultAccountManager{ Store: store, mux: sync.Mutex{}, @@ -222,9 +227,7 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err // GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and // user id doesn't have an account associated with it, one account is created -func (am *DefaultAccountManager) GetAccountByUserOrAccountId( - userId, accountId, domain string, -) (*Account, error) { +func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) { if accountId != "" { return am.GetAccountById(accountId) } else if userId != "" { @@ -242,9 +245,13 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId( return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided") } +func isNil(i idp.Manager) bool { + return i == nil || reflect.ValueOf(i).IsNil() +} + // updateIDPMetadata update user's app metadata in idp manager func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error { - if am.idpManager != nil { + if !isNil(am.idpManager) { err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID}) if err != nil { return status.Errorf( @@ -257,6 +264,55 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err return nil } +func mergeLocalAndQueryUser(queried idp.UserData, local User) *UserInfo { + return &UserInfo{ + ID: local.Id, + Email: queried.Email, + Name: queried.Name, + Role: string(local.Role), + } +} + +// GetUsersFromAccount performs a batched request for users from IDP by account id +func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) { + account, err := am.GetAccountById(accountID) + if err != nil { + return nil, err + } + + queriedUsers := make([]*idp.UserData, 0) + if !isNil(am.idpManager) { + queriedUsers, err = am.idpManager.GetBatchedUserData(accountID) + if err != nil { + return nil, err + } + } + + userInfo := make([]*UserInfo, 0) + + // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo + if len(queriedUsers) == 0 { + for _, user := range account.Users { + userInfo = append(userInfo, &UserInfo{ + ID: user.Id, + Email: "", + Name: "", + Role: string(user.Role), + }) + } + return userInfo, nil + } + + for _, queriedUser := range queriedUsers { + if localUser, contains := account.Users[queriedUser.ID]; contains { + userInfo = append(userInfo, mergeLocalAndQueryUser(*queriedUser, *localUser)) + log.Debugf("Merged userinfo to send back; %v", userInfo) + } + } + + return userInfo, nil +} + // updateAccountDomainAttributes updates the account domain attributes and then, saves the account func (am *DefaultAccountManager) updateAccountDomainAttributes( account *Account, diff --git a/management/server/account_test.go b/management/server/account_test.go index 5f44970d2..89cb2ec65 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1,11 +1,13 @@ package server import ( - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "net" "testing" + + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { @@ -190,6 +192,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { }) } } + func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -504,6 +507,39 @@ func TestAccountManager_DeletePeer(t *testing.T) { } +func TestGetUsersFromAccount(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + users := map[string]*User{"1": {Id: "1", Role: "admin"}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} + accountId := "test_account_id" + + account, err := manager.AddAccount(accountId, users["1"].Id, "") + if err != nil { + t.Fatal(err) + } + + // add a user to the account + for _, user := range users { + account.Users[user.Id] = user + } + + userInfos, err := manager.GetUsersFromAccount(accountId) + if err != nil { + t.Fatal(err) + } + + for _, userInfo := range userInfos { + id := userInfo.ID + assert.Equal(t, userInfo.ID, users[id].Id) + assert.Equal(t, string(userInfo.Role), string(users[id].Role)) + assert.Equal(t, userInfo.Name, "") + assert.Equal(t, userInfo.Email, "") + } +} + func createManager(t *testing.T) (*DefaultAccountManager, error) { store, err := createStore(t) if err != nil { diff --git a/management/server/http/handler/users.go b/management/server/http/handler/users.go new file mode 100644 index 000000000..a156e1f0e --- /dev/null +++ b/management/server/http/handler/users.go @@ -0,0 +1,63 @@ +package handler + +import ( + "fmt" + "net/http" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" +) + +type UserHandler struct { + accountManager server.AccountManager + authAudience string + jwtExtractor jwtclaims.ClaimsExtractor +} + +type UserResponse struct { + Email string + Role string +} + +func NewUserHandler(accountManager server.AccountManager, authAudience string) *UserHandler { + return &UserHandler{ + accountManager: accountManager, + authAudience: authAudience, + jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + } +} + +func (u *UserHandler) getAccountId(r *http.Request) (*server.Account, error) { + jwtClaims := u.jwtExtractor.ExtractClaimsFromRequestContext(r, u.authAudience) + + account, err := u.accountManager.GetAccountWithAuthorizationClaims(jwtClaims) + if err != nil { + return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) + } + + return account, nil +} + +// GetUsers returns a list of users of the account this user belongs to. +// It also gathers additional user data (like email and name) from the IDP manager. +func (u *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "", http.StatusBadRequest) + } + + account, err := u.getAccountId(r) + if err != nil { + log.Error(err) + } + + data, err := u.accountManager.GetUsersFromAccount(account.Id) + if err != nil { + log.Error(err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + writeJSONObject(w, data) +} diff --git a/management/server/http/handler/users_test.go b/management/server/http/handler/users_test.go new file mode 100644 index 000000000..7a42c0707 --- /dev/null +++ b/management/server/http/handler/users_test.go @@ -0,0 +1,106 @@ +package handler + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/mock_server" +) + +func initUsers(user ...*server.User) *UserHandler { + return &UserHandler{ + accountManager: &mock_server.MockAccountManager{ + GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + users := make(map[string]*server.User, 0) + for _, u := range user { + users[u.Id] = u + } + return &server.Account{ + Id: "12345", + Domain: "netbird.io", + Users: users, + }, nil + }, + GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) { + users := make([]*server.UserInfo, 0) + for _, v := range user { + users = append(users, &server.UserInfo{ + ID: v.Id, + Role: string(v.Role), + Name: "", + Email: "", + }) + } + return users, nil + }, + }, + authAudience: "", + jwtExtractor: jwtclaims.ClaimsExtractor{ + ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + } + }, + }, + } +} + +func TestGetUsers(t *testing.T) { + users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}} + userHandler := initUsers(users...) + + var tt = []struct { + name string + expectedStatus int + requestType string + requestPath string + requestBody io.Reader + expectedResult []*server.User + }{ + {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users}, + {name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + rr := httptest.NewRecorder() + + userHandler.GetUsers(rr, req) + + res := rr.Result() + defer res.Body.Close() + + if status := rr.Code; status != tc.expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + respBody := []*server.UserInfo{} + err = json.Unmarshal(content, &respBody) + if err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + if tc.expectedResult != nil { + for i, resp := range respBody { + assert.Equal(t, resp.ID, tc.expectedResult[i].Id) + assert.Equal(t, string(resp.Role), string(tc.expectedResult[i].Role)) + } + } + }) + } +} diff --git a/management/server/http/server.go b/management/server/http/server.go index 83030847f..70e8c392a 100644 --- a/management/server/http/server.go +++ b/management/server/http/server.go @@ -102,6 +102,12 @@ func (s *Server) Start() error { r.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") + userHandler := handler.NewUserHandler(s.accountManager, s.config.AuthAudience) + r.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS") + + r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS") + r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS") + r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS") r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey). Methods("GET", "PUT", "DELETE", "OPTIONS") diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index fe56dea99..a10d9220f 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -3,14 +3,17 @@ package idp import ( "encoding/json" "fmt" - "github.com/golang-jwt/jwt" - log "github.com/sirupsen/logrus" "io" "io/ioutil" "net/http" + "net/url" + "strconv" "strings" "sync" "time" + + "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" ) // Auth0Manager auth0 manager client instance @@ -94,7 +97,7 @@ func (c *Auth0Credentials) jwtStillValid() bool { // requestJWTToken performs request to get jwt token func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { var res *http.Response - url := c.clientConfig.AuthIssuer + "/oauth/token" + reqURL := c.clientConfig.AuthIssuer + "/oauth/token" p, err := c.helper.Marshal(auth0JWTRequest(c.clientConfig)) if err != nil { @@ -102,7 +105,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { } payload := strings.NewReader(string(p)) - req, err := http.NewRequest("POST", url, payload) + req, err := http.NewRequest("POST", reqURL, payload) if err != nil { return res, err } @@ -183,6 +186,133 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) { return c.jwtToken, nil } +func batchRequestUsersUrl(authIssuer, accountId string, page int) (string, url.Values, error) { + u, err := url.Parse(authIssuer + "/api/v2/users") + if err != nil { + return "", nil, err + } + q := u.Query() + q.Set("page", strconv.Itoa(page)) + q.Set("search_engine", "v3") + q.Set("q", "app_metadata.wt_account_id:"+accountId) + u.RawQuery = q.Encode() + + return u.String(), q, nil +} + +func requestByUserIdUrl(authIssuer, userId string) string { + return authIssuer + "/api/v2/users/" + userId +} + +// GetBatchedUserData requests users in batches from Auth0 +func (am *Auth0Manager) GetBatchedUserData(accountId string) ([]*UserData, error) { + jwtToken, err := am.credentials.Authenticate() + if err != nil { + return nil, err + } + + var list []*UserData + + // https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations + // auth0 limitation of 1000 users via this endpoint + for page := 0; page < 20; page++ { + reqURL, query, err := batchRequestUsersUrl(am.authIssuer, accountId, page) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode())) + if err != nil { + return nil, err + } + + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + res, err := am.httpClient.Do(req) + if err != nil { + return nil, err + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + var batch []UserData + err = json.Unmarshal(body, &batch) + if err != nil { + return nil, err + } + + log.Debugf("requested batch; %v", batch) + + err = res.Body.Close() + if err != nil { + return nil, err + } + + if res.StatusCode != 200 { + return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode) + } + + if len(batch) == 0 { + return list, nil + } + + for user := range batch { + list = append(list, &batch[user]) + } + } + + return list, nil +} + +// GetUserDataByID requests user data from auth0 via ID +func (am *Auth0Manager) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) { + jwtToken, err := am.credentials.Authenticate() + if err != nil { + return nil, err + } + + reqURL := requestByUserIdUrl(am.authIssuer, userId) + req, err := http.NewRequest(http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + res, err := am.httpClient.Do(req) + if err != nil { + return nil, err + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + var userData UserData + err = json.Unmarshal(body, &userData) + if err != nil { + return nil, err + } + + defer func() { + err = res.Body.Close() + if err != nil { + log.Errorf("error while closing update user app metadata response body: %v", err) + } + }() + + if res.StatusCode != 200 { + return nil, fmt.Errorf("unable to get UserData, statusCode %d", res.StatusCode) + } + + return &userData, nil +} + // UpdateUserAppMetadata updates user app metadata based on userId and metadata map func (am *Auth0Manager) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error { @@ -191,7 +321,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userId string, appMetadata AppMeta return err } - url := am.authIssuer + "/api/v2/users/" + userId + reqURL := am.authIssuer + "/api/v2/users/" + userId data, err := am.helper.Marshal(appMetadata) if err != nil { @@ -202,7 +332,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userId string, appMetadata AppMeta payload := strings.NewReader(payloadString) - req, err := http.NewRequest("PATCH", url, payload) + req, err := http.NewRequest("PATCH", reqURL, payload) if err != nil { return err } diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 0b31b2357..6991b893a 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -10,6 +10,8 @@ import ( // Manager idp manager interface type Manager interface { UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error + GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) + GetBatchedUserData(accountId string) ([]*UserData, error) } // Config an idp configuration struct to be loaded from management server's config file @@ -34,6 +36,12 @@ type ManagerHelper interface { Unmarshal(data []byte, v interface{}) error } +type UserData struct { + Email string `json:"email"` + Name string `json:"name"` + ID string `json:"user_id"` +} + // AppMetadata user app metadata to associate with a profile type AppMetadata struct { // Wiretrustee account id to update in the IDP diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b7aabe3c2..c62ced63c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -33,6 +33,15 @@ type MockAccountManager struct { GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) + GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) +} + +func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) { + if am.GetUsersFromAccountFunc != nil { + return am.GetUsersFromAccountFunc(accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented") + } func (am *MockAccountManager) GetOrCreateAccountByUser( diff --git a/management/server/user.go b/management/server/user.go index c7eb72d8e..55794b333 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -1,9 +1,10 @@ package server import ( + "strings" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "strings" ) const (