From de8608f99fb2c1a47ad6366c573518ef3b1ca0ad Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 21 Mar 2023 16:02:19 +0100 Subject: [PATCH 01/50] add rest endpoints and update openapi doc --- management/server/http/api/openapi.yml | 176 +++++++++++++++++++ management/server/http/api/types.gen.go | 33 ++++ management/server/http/handler.go | 9 + management/server/http/pat_handler.go | 187 +++++++++++++++++++++ management/server/http/pat_handler_test.go | 37 ++++ 5 files changed, 442 insertions(+) create mode 100644 management/server/http/pat_handler.go create mode 100644 management/server/http/pat_handler_test.go diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index b3d954a4d..3f742b850 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -6,6 +6,8 @@ info: tags: - name: Users description: Interact with and view information about users. + - name: Tokens + description: Interact with and view information about tokens. - name: Peers description: Interact with and view information about peers. - name: Setup Keys @@ -284,6 +286,53 @@ components: - revoked - auto_groups - usage_limit + PersonalAccessToken: + type: object + properties: + id: + description: ID of a token + type: string + description: + description: Description of the token + type: string +# hashed_token: +# description: Hashed representation of the token +# type: string + expiration_date: + description: Date the token expires + type: string + format: date-time + created_by: + description: User ID of the user who created the token + type: string + created_at: + description: Date the token was created + type: string + format: date-time + last_used: + description: Date the token was last used + type: string + format: date-time + required: + - id + - description +# - hashed_token + - expiration_date + - created_by + - created_at + - last_used + PersonalAccessTokenRequest: + type: object + properties: + description: + description: Description of the token + type: string + expires_in: + description: Expiration in days + type: integer + required: + - description + - expires_in GroupMinimum: type: object properties: @@ -848,6 +897,133 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/{userId}/tokens: + get: + summary: Returns a list of all tokens for a user + tags: [ Tokens ] + security: + - BearerAuth: [] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + responses: + '200': + description: A JSON Array of PersonalAccessTokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/PersonalAccessToken' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a new token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + requestBody: + description: PersonalAccessToken create parameters + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessTokenRequest' + responses: + '200': + description: The token in plain text + content: + text/plain: + schema: + type: string + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/{userId}/tokens/{tokenId}: + get: + summary: Returns a specific token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + - in: path + name: tokenId + required: true + schema: + type: string + description: The Token ID + responses: + '200': + description: A PersonalAccessTokens Object + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessToken' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + - in: path + name: tokenId + required: true + schema: + type: string + description: The Token ID + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers: get: summary: Returns a list of all peers diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 372ecd1a7..76c128d55 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -379,6 +379,36 @@ type PeerMinimum struct { Name string `json:"name"` } +// PersonalAccessToken defines model for PersonalAccessToken. +type PersonalAccessToken struct { + // CreatedAt Date the token was created + CreatedAt time.Time `json:"created_at"` + + // CreatedBy User ID of the user who created the token + CreatedBy string `json:"created_by"` + + // Description Description of the token + Description string `json:"description"` + + // ExpirationDate Date the token expires + ExpirationDate time.Time `json:"expiration_date"` + + // Id ID of a token + Id string `json:"id"` + + // LastUsed Date the token was last used + LastUsed time.Time `json:"last_used"` +} + +// PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest. +type PersonalAccessTokenRequest struct { + // Description Description of the token + Description string `json:"description"` + + // ExpiresIn Expiration in days + ExpiresIn int `json:"expires_in"` +} + // Policy defines model for Policy. type Policy struct { // Description Policy friendly description @@ -808,3 +838,6 @@ type PostApiUsersJSONRequestBody = UserCreateRequest // PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType. type PutApiUsersIdJSONRequestBody = UserRequest + +// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType. +type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 90f62e700..e2ed927a3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -57,6 +57,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics api.addAccountsEndpoint() api.addPeersEndpoint() api.addUsersEndpoint() + api.addUsersTokensEndpoint() api.addSetupKeysEndpoint() api.addRulesEndpoint() api.addPoliciesEndpoint() @@ -110,6 +111,14 @@ func (apiHandler *apiHandler) addUsersEndpoint() { apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") } +func (apiHandler *apiHandler) addUsersTokensEndpoint() { + tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) + apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") +} + func (apiHandler *apiHandler) addSetupKeysEndpoint() { keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go new file mode 100644 index 000000000..8cdef141a --- /dev/null +++ b/management/server/http/pat_handler.go @@ -0,0 +1,187 @@ +package http + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" +) + +// PATHandler is the nameserver group handler of the account +type PATHandler struct { + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { + return &PATHandler{ + accountManager: accountManager, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + userID := vars["userId"] + if len(userID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + if userID != user.Id { + util.WriteErrorResponse("User not authorized to get tokens", http.StatusUnauthorized, w) + return + } + + var pats []*api.PersonalAccessToken + for _, pat := range account.Users[userID].PATs { + pats = append(pats, toPATResponse(pat)) + } + + util.WriteJSONObject(w, pats) +} + +func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + userID := vars["userId"] + if len(userID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + if userID != user.Id { + util.WriteErrorResponse("User not authorized to get token", http.StatusUnauthorized, w) + return + } + + tokenID := vars["tokenId"] + if len(tokenID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + return + } + + pat := account.Users[userID].PATs[tokenID] + util.WriteJSONObject(w, toPATResponse(pat)) +} + +func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + userID := vars["userId"] + if len(userID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + if userID != user.Id { + util.WriteErrorResponse("User not authorized to create token", http.StatusUnauthorized, w) + return + } + + var req api.PostApiUsersUserIdTokensJSONRequestBody + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + pat, plainToken, err := server.CreateNewPAT(req.Description, req.ExpiresIn, user.Id) + err = h.accountManager.AddPATToUser(account.Id, userID, pat) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, plainToken) +} + +func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + userID := vars["userId"] + if len(userID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + if userID != user.Id { + util.WriteErrorResponse("User not authorized to delete token", http.StatusUnauthorized, w) + return + } + + tokenID := vars["tokenId"] + if len(tokenID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + return + } + + err = h.accountManager.DeletePAT(account.Id, userID, tokenID) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, "") +} + +func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { + return &api.PersonalAccessToken{ + CreatedAt: pat.CreatedAt, + CreatedBy: pat.CreatedBy, + Description: pat.Description, + ExpirationDate: pat.ExpirationDate, + Id: pat.ID, + LastUsed: pat.LastUsed, + } +} diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go new file mode 100644 index 000000000..3d32e7c30 --- /dev/null +++ b/management/server/http/pat_handler_test.go @@ -0,0 +1,37 @@ +package http + +import ( + "net/http" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" +) + +func initPATTestData() *PATHandler { + return &PATHandler{ + accountManager: &mock_server.MockAccountManager{ + + AddPATToUserFunc: func(accountID string, userID string, pat *server.PersonalAccessToken) error { + if nsGroupID == existingNSGroupID { + return baseExistingNSGroup.Copy(), nil + } + return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + }, + + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + return testingNSAccount, testingAccount.Users["test_user"], nil + }, + }, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: testNSGroupAccountID, + } + }), + ), + } +} From 9e74f30d2f1da8f24b5c2b563128e7a4374d1443 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Mon, 27 Mar 2023 15:19:19 +0200 Subject: [PATCH 02/50] fix delete token parameter lookup --- management/server/user.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/user.go b/management/server/user.go index c3011c317..572843fff 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -228,7 +228,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, userID string, toke return status.Errorf(status.NotFound, "user not found") } - pat := user.PATs["tokenID"] + pat := user.PATs[tokenID] if pat == nil { return status.Errorf(status.NotFound, "PAT not found") } From c65a9341077cee4c2f98784ae103e4e8a7bb7917 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Mon, 27 Mar 2023 16:28:49 +0200 Subject: [PATCH 03/50] refactor to use name instead of description --- management/server/account_test.go | 2 +- management/server/http/api/openapi.yml | 16 ++++++--------- management/server/http/api/types.gen.go | 12 +++++------ management/server/http/pat_handler.go | 24 ++-------------------- management/server/personal_access_token.go | 6 +++--- 5 files changed, 18 insertions(+), 42 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 5b4b1cc17..57b1cf3da 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1245,7 +1245,7 @@ func TestAccount_Copy(t *testing.T) { PATs: map[string]*PersonalAccessToken{ "pat1": { ID: "pat1", - Description: "First PAT", + Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", ExpirationDate: time.Now().AddDate(0, 0, 7), CreatedBy: "user1", diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 3f742b850..c9a373411 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -292,12 +292,9 @@ components: id: description: ID of a token type: string - description: - description: Description of the token + name: + description: Name of the token type: string -# hashed_token: -# description: Hashed representation of the token -# type: string expiration_date: description: Date the token expires type: string @@ -315,8 +312,7 @@ components: format: date-time required: - id - - description -# - hashed_token + - name - expiration_date - created_by - created_at @@ -324,14 +320,14 @@ components: PersonalAccessTokenRequest: type: object properties: - description: - description: Description of the token + name: + description: Name of the token type: string expires_in: description: Expiration in days type: integer required: - - description + - name - expires_in GroupMinimum: type: object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 76c128d55..4727e471e 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -387,9 +387,6 @@ type PersonalAccessToken struct { // CreatedBy User ID of the user who created the token CreatedBy string `json:"created_by"` - // Description Description of the token - Description string `json:"description"` - // ExpirationDate Date the token expires ExpirationDate time.Time `json:"expiration_date"` @@ -398,15 +395,18 @@ type PersonalAccessToken struct { // LastUsed Date the token was last used LastUsed time.Time `json:"last_used"` + + // Name Name of the token + Name string `json:"name"` } // PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest. type PersonalAccessTokenRequest struct { - // Description Description of the token - Description string `json:"description"` - // ExpiresIn Expiration in days ExpiresIn int `json:"expires_in"` + + // Name Name of the token + Name string `json:"name"` } // Policy defines model for Policy. diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 8cdef141a..04c1f369f 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -30,11 +30,6 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { @@ -62,11 +57,6 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { } func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { @@ -96,11 +86,6 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { } func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPut { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { @@ -126,7 +111,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, plainToken, err := server.CreateNewPAT(req.Description, req.ExpiresIn, user.Id) + pat, plainToken, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) err = h.accountManager.AddPATToUser(account.Id, userID, pat) if err != nil { util.WriteError(err, w) @@ -137,11 +122,6 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { } func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { @@ -179,7 +159,7 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { return &api.PersonalAccessToken{ CreatedAt: pat.CreatedAt, CreatedBy: pat.CreatedBy, - Description: pat.Description, + Name: pat.Name, ExpirationDate: pat.ExpirationDate, Id: pat.ID, LastUsed: pat.LastUsed, diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 7416a9e0b..817605dce 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -25,7 +25,7 @@ const ( // PersonalAccessToken holds all information about a PAT including a hashed version of it for verification type PersonalAccessToken struct { ID string - Description string + Name string HashedToken string ExpirationDate time.Time // scope could be added in future @@ -36,7 +36,7 @@ type PersonalAccessToken struct { // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version -func CreateNewPAT(description string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) { +func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) { hashedToken, plainToken, err := generateNewToken() if err != nil { return nil, "", err @@ -44,7 +44,7 @@ func CreateNewPAT(description string, expirationInDays int, createdBy string) (* currentTime := time.Now().UTC() return &PersonalAccessToken{ ID: xid.New().String(), - Description: description, + Name: name, HashedToken: hashedToken, ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), CreatedBy: createdBy, From b66e984dddc384fc8a21c25da0e8c4d4b0a0449c Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Mon, 27 Mar 2023 17:28:24 +0200 Subject: [PATCH 04/50] set limits for expiration --- management/server/http/pat_handler.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 04c1f369f..c7bcb92bc 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -111,6 +111,16 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } + if req.Name == "" { + util.WriteErrorResponse("name can't be empty", status.InvalidArgument, w) + return + } + + if req.ExpiresIn < 1 || req.ExpiresIn > 365 { + util.WriteErrorResponse("expiration has to be between 1 and 365", status.InvalidArgument, w) + return + } + pat, plainToken, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) err = h.accountManager.AddPATToUser(account.Id, userID, pat) if err != nil { From 6a75ec4ab775cb14d68251a71be6830f042c36d9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Mon, 27 Mar 2023 17:42:05 +0200 Subject: [PATCH 05/50] fix http error codes --- management/server/http/pat_handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index c7bcb92bc..034caa4e9 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -112,12 +112,12 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { } if req.Name == "" { - util.WriteErrorResponse("name can't be empty", status.InvalidArgument, w) + util.WriteErrorResponse("name can't be empty", http.StatusBadRequest, w) return } if req.ExpiresIn < 1 || req.ExpiresIn > 365 { - util.WriteErrorResponse("expiration has to be between 1 and 365", status.InvalidArgument, w) + util.WriteErrorResponse("expiration has to be between 1 and 365", http.StatusBadRequest, w) return } From 488d338ce83d98921964a0ea0e853688401919b4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Mar 2023 09:57:23 +0200 Subject: [PATCH 06/50] Refactor the authentication part of mobile exports (#759) Refactor the auth code into async calls for mobile framework --------- Co-authored-by: Maycon Santos --- client/android/client.go | 2 +- client/android/login.go | 51 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/client/android/client.go b/client/android/client.go index ac16316ed..5e3c0c85a 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -78,7 +78,7 @@ func (c *Client) Run(urlOpener URLOpener) error { c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.Login(urlOpener) + err = auth.login(urlOpener) if err != nil { return err } diff --git a/client/android/login.go b/client/android/login.go index 4e2f1ab30..0c11c0cce 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -17,6 +17,18 @@ import ( "github.com/netbirdio/netbird/client/internal" ) +// SSOListener is async listener for mobile framework +type SSOListener interface { + OnSuccess(bool) + OnError(error) +} + +// ErrListener is async listener for mobile framework +type ErrListener interface { + OnSuccess() + OnError(error) +} + // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { @@ -59,7 +71,18 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { // SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info. // If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO // is not supported and returns false without saving the configuration. For other errors return false. -func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { +func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { + go func() { + sso, err := a.saveConfigIfSSOSupported() + if err != nil { + listener.OnError(err) + } else { + listener.OnSuccess(sso) + } + }() +} + +func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) @@ -83,7 +106,18 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { } // LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. -func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { +func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) { + go func() { + err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) @@ -103,7 +137,18 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string } // Login try register the client on the server -func (a *Auth) Login(urlOpener URLOpener) error { +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { + go func() { + err := a.login(urlOpener) + if err != nil { + resultListener.OnError(err) + } else { + resultListener.OnSuccess() + } + }() +} + +func (a *Auth) login(urlOpener URLOpener) error { var needsLogin bool // check if we need to generate JWT token From 514403db370cedc82e02a17db98bdab73d0d3430 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 28 Mar 2023 14:47:15 +0200 Subject: [PATCH 07/50] use object instead of plain token for create response + handler test --- management/server/http/api/openapi.yml | 15 +- management/server/http/api/types.gen.go | 8 + management/server/http/pat_handler.go | 26 ++- management/server/http/pat_handler_test.go | 208 ++++++++++++++++++++- management/server/personal_access_token.go | 32 ++-- 5 files changed, 265 insertions(+), 24 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index c9a373411..2668198c4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -317,6 +317,17 @@ components: - created_by - created_at - last_used + PersonalAccessTokenGenerated: + type: object + properties: + plain_token: + description: Plain text representation of the generated token + type: string + personal_access_token: + $ref: '#/components/schemas/PersonalAccessToken' + required: + - plain_token + - personal_access_token PersonalAccessTokenRequest: type: object properties: @@ -945,9 +956,9 @@ paths: '200': description: The token in plain text content: - text/plain: + application/json: schema: - type: string + $ref: '#/components/schemas/PersonalAccessTokenGenerated' '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 4727e471e..24abaf829 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -400,6 +400,14 @@ type PersonalAccessToken struct { Name string `json:"name"` } +// PersonalAccessTokenGenerated defines model for PersonalAccessTokenGenerated. +type PersonalAccessTokenGenerated struct { + PersonalAccessToken PersonalAccessToken `json:"personal_access_token"` + + // PlainToken Plain text representation of the generated token + PlainToken string `json:"plain_token"` +} + // PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest. type PersonalAccessTokenRequest struct { // ExpiresIn Expiration in days diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 034caa4e9..ab356b8c9 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -81,7 +81,18 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat := account.Users[userID].PATs[tokenID] + user = account.Users[userID] + if user == nil { + util.WriteError(status.Errorf(status.NotFound, "user not found"), w) + return + } + + pat := user.PATs[tokenID] + if pat == nil { + util.WriteError(status.Errorf(status.NotFound, "PAT not found"), w) + return + } + util.WriteJSONObject(w, toPATResponse(pat)) } @@ -121,14 +132,14 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, plainToken, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) - err = h.accountManager.AddPATToUser(account.Id, userID, pat) + pat, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) + err = h.accountManager.AddPATToUser(account.Id, userID, &pat.PersonalAccessToken) if err != nil { util.WriteError(err, w) return } - util.WriteJSONObject(w, plainToken) + util.WriteJSONObject(w, toPATGeneratedResponse(pat)) } func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { @@ -175,3 +186,10 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { LastUsed: pat.LastUsed, } } + +func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { + return &api.PersonalAccessTokenGenerated{ + PlainToken: pat.PlainToken, + PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), + } +} diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 3d32e7c30..c0a7185cb 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -1,37 +1,231 @@ package http import ( + "bytes" + "encoding/json" + "fmt" + "io" "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" ) +const ( + existingAccountID = "existingAccountID" + notFoundAccountID = "notFoundAccountID" + existingUserID = "existingUserID" + notFoundUserID = "notFoundUserID" + existingTokenID = "existingTokenID" + notFoundTokenID = "notFoundTokenID" + domain = "hotmail.com" +) + +var testAccount = &server.Account{ + Id: existingAccountID, + Domain: domain, + Users: map[string]*server.User{ + existingUserID: { + Id: existingUserID, + PATs: map[string]*server.PersonalAccessToken{ + existingTokenID: { + ID: existingTokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + "token2": { + ID: "token2", + Name: "My second token", + HashedToken: "someOtherHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + func initPATTestData() *PATHandler { return &PATHandler{ accountManager: &mock_server.MockAccountManager{ - AddPATToUserFunc: func(accountID string, userID string, pat *server.PersonalAccessToken) error { - if nsGroupID == existingNSGroupID { - return baseExistingNSGroup.Copy(), nil + if accountID != existingAccountID { + return status.Errorf(status.NotFound, "account with ID %s not found", accountID) } - return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + if userID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", userID) + } + return nil }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + return testAccount, testAccount.Users[existingUserID], nil + }, + DeletePATFunc: func(accountID string, userID string, tokenID string) error { + if accountID != existingAccountID { + return status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if userID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", userID) + } + if tokenID != existingTokenID { + return status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", + UserId: existingUserID, + Domain: domain, AccountId: testNSGroupAccountID, } }), ), } } + +func TestTokenHandlers(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "Get All Tokens", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens", + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Not Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "Delete Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + }, + { + name: "Delete Not Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "POST OK", + requestType: http.MethodPost, + requestPath: "/api/users/" + existingUserID + "/tokens", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprint("{\"name\":\"name\",\"expires_in\":7}"))), + expectedStatus: http.StatusOK, + expectedBody: true, + }, + } + + p := initPATTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + switch tc.name { + case "POST OK": + got := &api.PersonalAccessTokenGenerated{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.NotEmpty(t, got.PlainToken) + assert.Equal(t, server.PATLength, len(got.PlainToken)) + case "Get All Tokens": + expectedTokens := []api.PersonalAccessToken{ + toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), + toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]), + } + + var got []api.PersonalAccessToken + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.True(t, cmp.Equal(got, expectedTokens)) + case "Get Existing Token": + expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]) + got := &api.PersonalAccessToken{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.True(t, cmp.Equal(*got, expectedToken)) + } + + }) + } +} + +func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { + return api.PersonalAccessToken{ + Id: serverToken.ID, + Name: serverToken.Name, + CreatedAt: serverToken.CreatedAt, + LastUsed: serverToken.LastUsed, + CreatedBy: serverToken.CreatedBy, + ExpirationDate: serverToken.ExpirationDate, + } +} diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 817605dce..6eab840b8 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -34,23 +34,33 @@ type PersonalAccessToken struct { LastUsed time.Time } +// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it +type PersonalAccessTokenGenerated struct { + PlainToken string + PersonalAccessToken +} + // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version -func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) { +func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { hashedToken, plainToken, err := generateNewToken() if err != nil { - return nil, "", err + return nil, err } currentTime := time.Now().UTC() - return &PersonalAccessToken{ - ID: xid.New().String(), - Name: name, - HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), - CreatedBy: createdBy, - CreatedAt: currentTime, - LastUsed: currentTime, - }, plainToken, nil + return &PersonalAccessTokenGenerated{ + PersonalAccessToken: PersonalAccessToken{ + ID: xid.New().String(), + Name: name, + HashedToken: hashedToken, + ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + CreatedBy: createdBy, + CreatedAt: currentTime, + LastUsed: currentTime, + }, + PlainToken: plainToken, + }, nil + } func generateNewToken() (string, string, error) { From 42ba0765c80ba173e0fad6a53d732f02fbfd7a21 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 28 Mar 2023 14:54:06 +0200 Subject: [PATCH 08/50] fix linter --- go.mod | 2 +- management/server/http/pat_handler.go | 5 +++++ management/server/http/pat_handler_test.go | 3 +-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index d6518467f..c74fee409 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/godbus/dbus/v5 v5.1.0 + github.com/google/go-cmp v0.5.9 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 @@ -89,7 +90,6 @@ require ( github.com/go-stack/stack v1.8.0 // indirect github.com/gobwas/glob v0.2.3 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index ab356b8c9..654ea44f8 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -133,6 +133,11 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { } pat, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) + if err != nil { + util.WriteError(err, w) + return + } + err = h.accountManager.AddPATToUser(account.Id, userID, &pat.PersonalAccessToken) if err != nil { util.WriteError(err, w) diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index c0a7185cb..0d29f1ad4 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -3,7 +3,6 @@ package http import ( "bytes" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" @@ -148,7 +147,7 @@ func TestTokenHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/users/" + existingUserID + "/tokens", requestBody: bytes.NewBuffer( - []byte(fmt.Sprint("{\"name\":\"name\",\"expires_in\":7}"))), + []byte("{\"name\":\"name\",\"expires_in\":7}")), expectedStatus: http.StatusOK, expectedBody: true, }, From 8ebd6ce9632f1aa533148ac3bec3f48bb6cfe419 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 29 Mar 2023 10:39:54 +0200 Subject: [PATCH 09/50] Add OnDisconnecting service callback (#767) Add OnDisconnecting service callback for mobile --- client/internal/connect.go | 1 + client/internal/peer/listener.go | 1 + client/internal/peer/notifier.go | 17 ++++++++++++++++- client/internal/peer/status.go | 5 +++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 3aca0bab9..47c63e6d0 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -164,6 +164,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, state.Set(StatusConnected) <-engineCtx.Done() + statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/peer/listener.go b/client/internal/peer/listener.go index c8dc0fe70..c601fe534 100644 --- a/client/internal/peer/listener.go +++ b/client/internal/peer/listener.go @@ -5,6 +5,7 @@ type Listener interface { OnConnected() OnDisconnected() OnConnecting() + OnDisconnecting() OnAddressChanged(string, string) OnPeersListChanged(int) } diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go index efc9e47ad..4e618d2f8 100644 --- a/client/internal/peer/notifier.go +++ b/client/internal/peer/notifier.go @@ -8,6 +8,7 @@ const ( stateDisconnected = iota stateConnected stateConnecting + stateDisconnecting ) type notifier struct { @@ -57,8 +58,12 @@ func (n *notifier) updateServerStates(mgmState bool, signalState bool) { } n.currentServerState = newState - n.lastNotification = n.calculateState(newState, n.currentClientState) + if n.lastNotification == stateDisconnecting { + return + } + + n.lastNotification = n.calculateState(newState, n.currentClientState) go n.notifyAll(n.lastNotification) } @@ -78,6 +83,14 @@ func (n *notifier) clientStop() { go n.notifyAll(n.lastNotification) } +func (n *notifier) clientTearDown() { + n.serverStateLock.Lock() + defer n.serverStateLock.Unlock() + n.currentClientState = false + n.lastNotification = stateDisconnecting + go n.notifyAll(n.lastNotification) +} + func (n *notifier) isServerStateChanged(newState bool) bool { return n.currentServerState != newState } @@ -99,6 +112,8 @@ func (n *notifier) notifyListener(l Listener, state int) { l.OnConnected() case stateConnecting: l.OnConnecting() + case stateDisconnecting: + l.OnDisconnecting() } } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 1ecdff301..62841d6fc 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -288,6 +288,11 @@ func (d *Status) ClientStop() { d.notifier.clientStop() } +// ClientTeardown will notify all listeners about the service is under teardown +func (d *Status) ClientTeardown() { + d.notifier.clientTearDown() +} + // AddConnectionListener add a listener to the notifier func (d *Status) AddConnectionListener(listener Listener) { d.notifier.addListener(listener) From ab0cf1b8aa1ebf1d9701f65ca503eb505cb7ca20 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 29 Mar 2023 10:40:31 +0200 Subject: [PATCH 10/50] Fix slice bounds out of range in msg decryption (#768) --- encryption/encryption.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/encryption/encryption.go b/encryption/encryption.go index 196c42106..1c6ec7806 100644 --- a/encryption/encryption.go +++ b/encryption/encryption.go @@ -3,10 +3,13 @@ package encryption import ( "crypto/rand" "fmt" + "golang.org/x/crypto/nacl/box" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const nonceSize = 24 + // A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service // These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate) // Wireguard keys are used for encryption @@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. if err != nil { return nil, err } - copy(nonce[:], encryptedMsg[:24]) - opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) + if len(encryptedMsg) < nonceSize { + return nil, fmt.Errorf("invalid encrypted message lenght") + } + copy(nonce[:], encryptedMsg[:nonceSize]) + opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey)) if !ok { return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String()) } @@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes. } // Generates nonce of size 24 -func genNonce() (*[24]byte, error) { - var nonce [24]byte +func genNonce() (*[nonceSize]byte, error) { + var nonce [nonceSize]byte if _, err := rand.Read(nonce[:]); err != nil { return nil, err } From dfb7960cd474d3475c13c06e535f36f57d935a7d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 29 Mar 2023 10:41:14 +0200 Subject: [PATCH 11/50] Fix pre-shared key query name for android configuration (#773) --- iface/ipc_parser_android.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iface/ipc_parser_android.go b/iface/ipc_parser_android.go index ef757a638..e1dd66856 100644 --- a/iface/ipc_parser_android.go +++ b/iface/ipc_parser_android.go @@ -33,7 +33,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { if p.PresharedKey != nil { preSharedHexKey := hex.EncodeToString(p.PresharedKey[:]) - sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey)) + sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey)) } if p.Remove { From 726ffb57405c08a45024ec72b04d57d44c60f0cd Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 15:06:54 +0200 Subject: [PATCH 12/50] add comments for exported functions --- management/server/http/pat_handler.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 654ea44f8..7a8175fbf 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -19,6 +19,7 @@ type PATHandler struct { claimsExtractor *jwtclaims.ClaimsExtractor } +// NewPATsHandler creates a new PATHandler HTTP handler func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { return &PATHandler{ accountManager: accountManager, @@ -29,6 +30,7 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } } +// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) @@ -56,6 +58,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(w, pats) } +// GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) @@ -96,6 +99,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(w, toPATResponse(pat)) } +// CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) @@ -147,6 +151,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(w, toPATGeneratedResponse(pat)) } +// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) account, user, err := h.accountManager.GetAccountFromToken(claims) From c5942e6b33362684e7baf7dd56ea7c0cc7e1c2f6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 15:21:53 +0200 Subject: [PATCH 13/50] store hashed token base64 encoded --- management/server/account.go | 34 +++++++++++++------ management/server/account_test.go | 9 +++-- management/server/file_store.go | 12 ++++--- management/server/personal_access_token.go | 4 ++- .../server/personal_access_token_test.go | 4 ++- management/server/user_test.go | 2 +- 6 files changed, 43 insertions(+), 22 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 1d4c10721..55a5a0299 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/sha256" + b64 "encoding/base64" "fmt" "hash/crc32" "math/rand" @@ -54,7 +55,7 @@ type AccountManager interface { GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) - GetAccountFromPAT(pat string) (*Account, *User, error) + GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) @@ -1120,44 +1121,55 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, error) { +func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { - return nil, nil, fmt.Errorf("token has wrong length") + return nil, nil, nil, fmt.Errorf("token has wrong length") } + log.Debugf("Token: %s", token) + prefix := token[:len(PATPrefix)] if prefix != PATPrefix { - return nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, nil, fmt.Errorf("token has wrong prefix") } secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { - return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { - return nil, nil, fmt.Errorf("token checksum does not match") + return nil, nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) - tokenID, err := am.Store.GetTokenIDByHashedToken(string(hashedToken[:])) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken) if err != nil { - return nil, nil, err + return nil, nil, nil, err } + log.Debugf("TokenID: %s", tokenID) user, err := am.Store.GetUserByTokenID(tokenID) + log.Debugf("User: %v", user) if err != nil { - return nil, nil, err + return nil, nil, nil, err } account, err := am.Store.GetAccountByUser(user.Id) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return account, user, nil + + pat := user.PATs[tokenID] + if pat == nil { + return nil, nil, nil, fmt.Errorf("personal access token not found") + } + + return account, user, pat, nil } // GetAccountFromToken returns an account associated with this token diff --git a/management/server/account_test.go b/management/server/account_test.go index 57b1cf3da..8eea04362 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "fmt" "net" "reflect" @@ -465,12 +466,13 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) account.Users["someUser"] = &User{ Id: "someUser", PATs: map[string]*PersonalAccessToken{ - "pat1": { + "tokenId": { ID: "tokenId", - HashedToken: string(hashedToken[:]), + HashedToken: encodedHashedToken, }, }, } @@ -483,13 +485,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { Store: store, } - account, user, err := am.GetAccountFromPAT(token) + account, user, pat, err := am.GetAccountFromPAT(token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } assert.Equal(t, "account_id", account.Id) assert.Equal(t, "someUser", user.Id) + assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) } func TestAccountManager_PrivateAccount(t *testing.T) { diff --git a/management/server/file_store.go b/management/server/file_store.go index 4f8092cfb..b09dccf81 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -112,7 +112,7 @@ func restore(file string) (*FileStore, error) { store.UserID2AccountID[user.Id] = accountID for _, pat := range user.PATs { store.TokenID2UserID[pat.ID] = user.Id - store.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + store.HashedPAT2TokenID[pat.HashedToken] = pat.ID } } @@ -268,7 +268,7 @@ func (s *FileStore) SaveAccount(account *Account) error { s.UserID2AccountID[user.Id] = accountCopy.Id for _, pat := range user.PATs { s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + s.HashedPAT2TokenID[pat.HashedToken] = pat.ID } } @@ -349,11 +349,13 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { s.mux.Lock() defer s.mux.Unlock() + log.Debugf("TOken still there: %v", token) + log.Debugf("TokenID2UserId %v", s.HashedPAT2TokenID) tokenID, ok := s.HashedPAT2TokenID[token] if !ok { return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") } - + log.Debugf("TokenID for token %s is %s", token, tokenID) return tokenID, nil } @@ -366,12 +368,12 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { if !ok { return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") } - + log.Debugf("UserID for tokenID %s is %s", tokenID, userID) accountID, ok := s.UserID2AccountID[userID] if !ok { return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") } - + log.Debugf("AccountID for userID %s is %s", userID, accountID) account, err := s.getAccount(accountID) if err != nil { return nil, err diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 6eab840b8..a7c55018f 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "fmt" "hash/crc32" "time" @@ -74,5 +75,6 @@ func generateNewToken() (string, string, error) { paddedChecksum := fmt.Sprintf("%06s", encodedChecksum) plainToken := PATPrefix + secret + paddedChecksum hashedToken := sha256.Sum256([]byte(plainToken)) - return string(hashedToken[:]), plainToken, nil + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + return encodedHashedToken, plainToken, nil } diff --git a/management/server/personal_access_token_test.go b/management/server/personal_access_token_test.go index a4e02f750..03dd2ef4e 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/personal_access_token_test.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "hash/crc32" "strings" "testing" @@ -13,7 +14,8 @@ import ( func TestPAT_GenerateToken_Hashing(t *testing.T) { hashedToken, plainToken, _ := generateNewToken() expectedToken := sha256.Sum256([]byte(plainToken)) - assert.Equal(t, hashedToken, string(expectedToken[:])) + encodedExpectedToken := b64.StdEncoding.EncodeToString(expectedToken[:]) + assert.Equal(t, hashedToken, encodedExpectedToken) } func TestPAT_GenerateToken_Prefix(t *testing.T) { diff --git a/management/server/user_test.go b/management/server/user_test.go index 20f2ca4f1..ffbb71282 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -37,7 +37,7 @@ func TestUser_AddPATToUser(t *testing.T) { } fileStore := am.Store.(*FileStore) - tokenID := fileStore.HashedPAT2TokenID[mockToken[:]] + tokenID := fileStore.HashedPAT2TokenID[mockToken] if tokenID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") From 0ca3d27a808b9ffb936e9a95a97d3f57faca9573 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 15:25:44 +0200 Subject: [PATCH 14/50] update account mock --- management/server/mock_server/account_mock.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2ae71d1a9..71870fd84 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,7 +47,7 @@ type MockAccountManager struct { DeletePolicyFunc func(accountID, policyID, userID string) error ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, error) + GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) @@ -179,11 +179,11 @@ func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*ser } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, error) { +func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { return am.GetAccountFromPATFunc(pat) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } // AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface From 3bab7451429f63f631f8f49bbba66efda906b0e7 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 17:46:09 +0200 Subject: [PATCH 15/50] last_used can be nil --- management/server/http/api/openapi.yml | 1 - management/server/http/api/types.gen.go | 2 +- management/server/http/pat_handler.go | 7 ++++++- management/server/personal_access_token.go | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2668198c4..eaeb5693c 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -316,7 +316,6 @@ components: - expiration_date - created_by - created_at - - last_used PersonalAccessTokenGenerated: type: object properties: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 24abaf829..930a9df54 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -394,7 +394,7 @@ type PersonalAccessToken struct { Id string `json:"id"` // LastUsed Date the token was last used - LastUsed time.Time `json:"last_used"` + LastUsed *time.Time `json:"last_used,omitempty"` // Name Name of the token Name string `json:"name"` diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 7a8175fbf..2f6cb1492 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -3,6 +3,7 @@ package http import ( "encoding/json" "net/http" + "time" "github.com/gorilla/mux" @@ -187,13 +188,17 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { + var lastUsed *time.Time + if !pat.LastUsed.IsZero() { + lastUsed = &pat.LastUsed + } return &api.PersonalAccessToken{ CreatedAt: pat.CreatedAt, CreatedBy: pat.CreatedBy, Name: pat.Name, ExpirationDate: pat.ExpirationDate, Id: pat.ID, - LastUsed: pat.LastUsed, + LastUsed: lastUsed, } } diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index a7c55018f..bdf34e9fd 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -57,7 +57,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), CreatedBy: createdBy, CreatedAt: currentTime, - LastUsed: currentTime, + LastUsed: time.Time{}, }, PlainToken: plainToken, }, nil From 4ec6d5d20ba9b34661c3abcffff49093c56daa14 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 18:23:10 +0200 Subject: [PATCH 16/50] remove debug logs --- management/server/file_store.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/management/server/file_store.go b/management/server/file_store.go index b09dccf81..9b4a9f47f 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -349,13 +349,11 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { s.mux.Lock() defer s.mux.Unlock() - log.Debugf("TOken still there: %v", token) - log.Debugf("TokenID2UserId %v", s.HashedPAT2TokenID) tokenID, ok := s.HashedPAT2TokenID[token] if !ok { return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") } - log.Debugf("TokenID for token %s is %s", token, tokenID) + return tokenID, nil } @@ -368,12 +366,12 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { if !ok { return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") } - log.Debugf("UserID for tokenID %s is %s", tokenID, userID) + accountID, ok := s.UserID2AccountID[userID] if !ok { return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") } - log.Debugf("AccountID for userID %s is %s", userID, accountID) + account, err := s.getAccount(accountID) if err != nil { return nil, err From 9746a7f61abe9fb9f0fce1682fed0766750a713e Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 18:27:01 +0200 Subject: [PATCH 17/50] remove debug logs --- management/server/account.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 55a5a0299..3db6a2fe0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1126,8 +1126,6 @@ func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *Use return nil, nil, nil, fmt.Errorf("token has wrong length") } - log.Debugf("Token: %s", token) - prefix := token[:len(PATPrefix)] if prefix != PATPrefix { return nil, nil, nil, fmt.Errorf("token has wrong prefix") @@ -1151,10 +1149,8 @@ func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *Use if err != nil { return nil, nil, nil, err } - log.Debugf("TokenID: %s", tokenID) user, err := am.Store.GetUserByTokenID(tokenID) - log.Debugf("User: %v", user) if err != nil { return nil, nil, nil, err } From 03abdfa11211a346c3f2c837a881798786fdc9ec Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 18:46:40 +0200 Subject: [PATCH 18/50] return empty object on all handlers instead of empty string --- management/server/http/groups_handler.go | 2 +- management/server/http/handler.go | 3 +++ management/server/http/nameservers_handler.go | 2 +- management/server/http/pat_handler.go | 2 +- management/server/http/peers_handler.go | 2 +- management/server/http/policies.go | 2 +- management/server/http/routes_handler.go | 2 +- management/server/http/rules_handler.go | 2 +- 8 files changed, 10 insertions(+), 7 deletions(-) diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 9712d2e75..2464f47ef 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -300,7 +300,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetGroup returns a group diff --git a/management/server/http/handler.go b/management/server/http/handler.go index e2ed927a3..79028e6d2 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -25,6 +25,9 @@ type apiHandler struct { AuthCfg AuthCfg } +type emptyObject struct { +} + // APIHandler creates the Management service HTTP API handler registering all the available endpoints. func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { jwtMiddleware, err := middleware.NewJwtMiddleware( diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index e7be617e2..5ad52a426 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -243,7 +243,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetNameserverGroup handles a nameserver group Get request identified by ID diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 2f6cb1492..d3e8b9ac5 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -184,7 +184,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 76c4f7502..7379277af 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -66,7 +66,7 @@ func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w htt util.WriteError(err, w) return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations diff --git a/management/server/http/policies.go b/management/server/http/policies.go index 275992d14..a0fe3b1e2 100644 --- a/management/server/http/policies.go +++ b/management/server/http/policies.go @@ -225,7 +225,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetPolicy handles a group Get request identified by ID diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index b29e5c261..aaaaaa854 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -321,7 +321,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetRoute handles a route Get request identified by ID diff --git a/management/server/http/rules_handler.go b/management/server/http/rules_handler.go index 8925c3763..f8bb5f0cb 100644 --- a/management/server/http/rules_handler.go +++ b/management/server/http/rules_handler.go @@ -222,7 +222,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetRule handles a group Get request identified by ID From ecc4f8a10d2466021a4862b6cc631a86c4425ea3 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 29 Mar 2023 19:13:01 +0200 Subject: [PATCH 19/50] fix Pat handler test --- management/server/http/pat_handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 0d29f1ad4..7b83c5db8 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -223,7 +223,7 @@ func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessT Id: serverToken.ID, Name: serverToken.Name, CreatedAt: serverToken.CreatedAt, - LastUsed: serverToken.LastUsed, + LastUsed: &serverToken.LastUsed, CreatedBy: serverToken.CreatedBy, ExpirationDate: serverToken.ExpirationDate, } From db3a9f0aa2d294c4a010f70b60dc9226d7aaf7ee Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 10:54:09 +0200 Subject: [PATCH 20/50] refactor jwt token validation and add PAT to middleware auth --- management/cmd/management.go | 26 +- management/server/account.go | 28 ++ management/server/grpcserver.go | 25 +- management/server/http/handler.go | 20 +- .../server/http/middleware/auth_middleware.go | 170 ++++++++++++ .../http/middleware/auth_midleware_test.go | 1 + management/server/http/middleware/jwt.go | 249 ------------------ management/server/http/util/util.go | 8 +- .../handler.go => jwtclaims/jwtValidator.go} | 90 ++++++- management/server/mock_server/account_mock.go | 9 + 10 files changed, 341 insertions(+), 285 deletions(-) create mode 100644 management/server/http/middleware/auth_middleware.go create mode 100644 management/server/http/middleware/auth_midleware_test.go delete mode 100644 management/server/http/middleware/jwt.go rename management/server/{http/middleware/handler.go => jwtclaims/jwtValidator.go} (55%) diff --git a/management/cmd/management.go b/management/cmd/management.go index f3210d88e..620a89f16 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -19,25 +19,28 @@ import ( "github.com/google/uuid" "github.com/miekg/dns" - "github.com/netbirdio/netbird/management/server/activity/sqlite" - httpapi "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/metrics" - "github.com/netbirdio/netbird/management/server/telemetry" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/netbirdio/netbird/management/server/activity/sqlite" + httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/util" - "github.com/netbirdio/netbird/encryption" - mgmtProto "github.com/netbirdio/netbird/management/proto" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + + "github.com/netbirdio/netbird/encryption" + mgmtProto "github.com/netbirdio/netbird/management/proto" ) // ManagementLegacyPort is the port that was used before by the Management gRPC server. @@ -179,13 +182,22 @@ var ( tlsEnabled = true } + jwtValidator, err := jwtclaims.NewJWTValidator( + config.HttpConfig.AuthIssuer, + config.HttpConfig.AuthAudience, + config.HttpConfig.AuthKeysLocation, + ) + if err != nil { + return fmt.Errorf("failed creating JWT validator: %v", err) + } + httpAPIAuthCfg := httpapi.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg) + httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 3db6a2fe0..e2cdf40b2 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -56,6 +56,7 @@ type AccountManager interface { GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) + MarkPATUsed(tokenID string) error IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) @@ -1120,6 +1121,33 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { + unlock := am.Store.AcquireGlobalLock() + defer unlock() + + user, err := am.Store.GetUserByTokenID(tokenID) + log.Debugf("User: %v", user) + if err != nil { + return err + } + + account, err := am.Store.GetAccountByUser(user.Id) + if err != nil { + return err + } + + pat, ok := account.Users[user.Id].PATs[tokenID] + if !ok { + return fmt.Errorf("token not found") + } + + pat.LastUsed = time.Now() + + am.Store.SaveAccount(account) + + return nil +} + // GetAccountFromPAT returns Account and User associated with a personal access token func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index fa0e49ed3..34f74cff9 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,24 +3,25 @@ package server import ( "context" "fmt" - pb "github.com/golang/protobuf/proto" //nolint "strings" "time" + pb "github.com/golang/protobuf/proto" // nolint + "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/golang/protobuf/ptypes/timestamp" - "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/management/proto" - internalStatus "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" gRPCPeer "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/proto" + internalStatus "github.com/netbirdio/netbird/management/server/status" ) // GRPCServer an instance of a Management gRPC API server @@ -31,7 +32,7 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config turnCredentialsManager TURNCredentialsManager - jwtMiddleware *middleware.JWTMiddleware + jwtValidator *jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics } @@ -45,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, err } - var jwtMiddleware *middleware.JWTMiddleware + var jwtValidator *jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtMiddleware, err = middleware.NewJwtMiddleware( + jwtValidator, err = jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation) @@ -86,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager accountManager: accountManager, config: config, turnCredentialsManager: turnCredentialsManager, - jwtMiddleware: jwtMiddleware, + jwtValidator: jwtValidator, jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, }, nil @@ -187,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { } func (s *GRPCServer) validateToken(jwtToken string) (string, error) { - if s.jwtMiddleware == nil { - return "", status.Error(codes.Internal, "no jwt middleware set") + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 79028e6d2..d8117a436 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -8,6 +8,7 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -25,18 +26,17 @@ type apiHandler struct { AuthCfg AuthCfg } +// EmptyObject is an empty struct used to return empty JSON object type emptyObject struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { - jwtMiddleware, err := middleware.NewJwtMiddleware( - authCfg.Issuer, - authCfg.Audience, - authCfg.KeysLocation) - if err != nil { - return nil, err - } +func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { + authMiddleware := middleware.NewAuthMiddleware( + accountManager.GetAccountFromPAT, + jwtValidator.ValidateAndParse, + accountManager.MarkPATUsed, + authCfg.Audience) corsMiddleware := cors.AllowAll() @@ -49,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics metricsMiddleware := appMetrics.HTTPMiddleware() router := rootRouter.PathPrefix("/api").Subrouter() - router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) + router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) api := apiHandler{ Router: router, @@ -70,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics api.addDNSSettingEndpoint() api.addEventsEndpoint() - err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { + err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { methods, err := route.GetMethods() if err != nil { return err diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go new file mode 100644 index 000000000..0b1478ef3 --- /dev/null +++ b/management/server/http/middleware/auth_middleware.go @@ -0,0 +1,170 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" +) + +type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) +type MarkPATUsedFunc func(token string) error + +// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens +type AuthMiddleware struct { + getAccountFromPAT GetAccountFromPATFunc + validateAndParseToken ValidateAndParseTokenFunc + markPATUsed MarkPATUsedFunc + audience string +} + +const ( + userProperty = "user" +) + +// NewAuthMiddleware instance constructor +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware { + return &AuthMiddleware{ + getAccountFromPAT: getAccountFromPAT, + validateAndParseToken: validateAndParseToken, + markPATUsed: markPATUsed, + audience: audience, + } +} + +// Handler method of the middleware which authenticates a user either by JWT claims or by PAT +func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.Split(r.Header.Get("Authorization"), " ") + authType := auth[0] + switch strings.ToLower(authType) { + case "bearer": + err := a.CheckJWTFromRequest(w, r) + if err != nil { + log.Debugf("Error when validating JWT claims: %s", err.Error()) + util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + return + } + h.ServeHTTP(w, r) + case "token": + err := a.CheckPATFromRequest(w, r) + if err != nil { + log.Debugf("Error when validating PAT claims: %s", err.Error()) + util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + return + } + h.ServeHTTP(w, r) + default: + util.WriteError(status.Errorf(status.Unauthorized, "No valid authentication provided"), w) + return + } + }) +} + +// CheckJWTFromRequest checks if the JWT is valid +func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { + + token, err := getTokenFromJWTRequest(r) + + // If an error occurs, call the error handler and return an error + if err != nil { + return fmt.Errorf("Error extracting token: %w", err) + } + + validatedToken, err := m.validateAndParseToken(token) + if err != nil { + return err + } + + if validatedToken == nil { + return nil + } + + // If we get here, everything worked and we can set the + // user property in context. + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + // Update the current request with the new context information. + *r = *newRequest + return nil +} + +// CheckPATFromRequest checks if the PAT is valid +func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error { + token, err := getTokenFromPATRequest(r) + + // If an error occurs, call the error handler and return an error + if err != nil { + return fmt.Errorf("Error extracting token: %w", err) + } + + account, user, pat, err := m.getAccountFromPAT(token) + if err != nil { + util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + return fmt.Errorf("invalid Token: %w", err) + } + if time.Now().After(pat.ExpirationDate) { + util.WriteError(status.Errorf(status.Unauthorized, "Token expired"), w) + return fmt.Errorf("token expired") + } + + err = m.markPATUsed(pat.ID) + if err != nil { + return err + } + + claimMaps := jwt.MapClaims{} + claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) + // Update the current request with the new context information. + *r = *newRequest + return nil +} + +// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts +// the JWT token from the Authorization header. +func getTokenFromJWTRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no token + } + + // TODO: Make this a bit more robust, parsing-wise + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts +// the PAT token from the Authorization header. +func getTokenFromPATRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no token + } + + // TODO: Make this a bit more robust, parsing-wise + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" { + return "", errors.New("Authorization header format must be Token {token}") + } + + return authHeaderParts[1], nil +} diff --git a/management/server/http/middleware/auth_midleware_test.go b/management/server/http/middleware/auth_midleware_test.go new file mode 100644 index 000000000..c870d7c16 --- /dev/null +++ b/management/server/http/middleware/auth_midleware_test.go @@ -0,0 +1 @@ +package middleware diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go deleted file mode 100644 index feb00ec86..000000000 --- a/management/server/http/middleware/jwt.go +++ /dev/null @@ -1,249 +0,0 @@ -package middleware - -import ( - "context" - "errors" - "fmt" - "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" - "log" - "net/http" - "strings" -) - -// A function called whenever an error is encountered -type errorHandler func(w http.ResponseWriter, r *http.Request, err string) - -// TokenExtractor is a function that takes a request as input and returns -// either a token or an error. An error should only be returned if an attempt -// to specify a token was found, but the information was somehow incorrectly -// formed. In the case where a token is simply not present, this should not -// be treated as an error. An empty string should be returned in that case. -type TokenExtractor func(r *http.Request) (string, error) - -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod -} - -type JWTMiddleware struct { - Options Options -} - -func OnError(w http.ResponseWriter, r *http.Request, err string) { - util.WriteError(status.Errorf(status.Unauthorized, ""), w) -} - -// New constructs a new Secure instance with supplied options. -func New(options ...Options) *JWTMiddleware { - - var opts Options - if len(options) == 0 { - opts = Options{} - } else { - opts = options[0] - } - - if opts.UserProperty == "" { - opts.UserProperty = "user" - } - - if opts.ErrorHandler == nil { - opts.ErrorHandler = OnError - } - - if opts.Extractor == nil { - opts.Extractor = FromAuthHeader - } - - return &JWTMiddleware{ - Options: opts, - } -} - -func (m *JWTMiddleware) logf(format string, args ...interface{}) { - if m.Options.Debug { - log.Printf(format, args...) - } -} - -// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere. -func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - err := m.CheckJWTFromRequest(w, r) - - // If there was an error, do not call next. - if err == nil && next != nil { - next(w, r) - } -} - -func (m *JWTMiddleware) Handler(h http.Handler) http.Handler { - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Let secure process the request. If it returns an error, - // that indicates the request should not continue. - err := m.CheckJWTFromRequest(w, r) - - // If there was an error, do not continue. - if err != nil { - return - } - - h.ServeHTTP(w, r) - }) -} - -// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts -// the JWT token from the Authorization header. -func FromAuthHeader(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") - } - - return authHeaderParts[1], nil -} - -// FromParameter returns a function that extracts the token from the specified -// query string parameter -func FromParameter(param string) TokenExtractor { - return func(r *http.Request) (string, error) { - return r.URL.Query().Get(param), nil - } -} - -// FromFirst returns a function that runs multiple token extractors and takes the -// first token it finds -func FromFirst(extractors ...TokenExtractor) TokenExtractor { - return func(r *http.Request) (string, error) { - for _, ex := range extractors { - token, err := ex(r) - if err != nil { - return "", err - } - if token != "" { - return token, nil - } - } - return "", nil - } -} - -func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error { - if !m.Options.EnableAuthOnOptions { - if r.Method == "OPTIONS" { - return nil - } - } - - // Use the specified token extractor to extract a token from the request - token, err := m.Options.Extractor(r) - - // If debugging is turned on, log the outcome - if err != nil { - m.logf("Error extracting JWT: %v", err) - } else { - m.logf("Token extracted: %s", token) - } - - // If an error occurs, call the error handler and return an error - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return fmt.Errorf("Error extracting token: %w", err) - } - - validatedToken, err := m.ValidateAndParse(token) - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return err - } - - if validatedToken == nil { - return nil - } - - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint - // Update the current request with the new context information. - *r = *newRequest - return nil -} - -// ValidateAndParse validates and parses a given access token against jwt standards and signing methods -func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) { - // If the token is empty... - if token == "" { - // Check if it was required - if m.Options.CredentialsOptional { - m.logf("no credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil, nil - } - - // If we get here, the required token is missing - errorMsg := "required authorization token not found" - m.logf(" Error: No credentials found (CredentialsOptional=false)") - return nil, fmt.Errorf(errorMsg) - } - - // Now parse the token - parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter) - - // Check if there was an error in parsing... - if err != nil { - m.logf("error parsing token: %v", err) - return nil, fmt.Errorf("Error parsing token: %w", err) - } - - if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] { - errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", - m.Options.SigningMethod.Alg(), - parsedToken.Header["alg"]) - m.logf("error validating token algorithm: %s", errorMsg) - return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) - } - - // Check if the parsed token is valid... - if !parsedToken.Valid { - errorMsg := "token is invalid" - m.logf(errorMsg) - return nil, errors.New(errorMsg) - } - - return parsedToken, nil -} diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 0055511a2..c40daa1a3 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -4,10 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" "net/http" "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/status" ) // WriteJSONObject simply writes object to the HTTP reponse in JSON format @@ -93,6 +95,8 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusInternalServerError case status.InvalidArgument: httpStatus = http.StatusUnprocessableEntity + case status.Unauthorized: + httpStatus = http.StatusUnauthorized default: } msg = err.Error() diff --git a/management/server/http/middleware/handler.go b/management/server/jwtclaims/jwtValidator.go similarity index 55% rename from management/server/http/middleware/handler.go rename to management/server/jwtclaims/jwtValidator.go index c647506bc..d324c5ab3 100644 --- a/management/server/http/middleware/handler.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -1,4 +1,4 @@ -package middleware +package jwtclaims import ( "bytes" @@ -17,6 +17,32 @@ import ( log "github.com/sirupsen/logrus" ) +// Options is a struct for specifying configuration options for the middleware. +type Options struct { + // The function that will return the Key to validate the JWT. + // It can be either a shared secret or a public key. + // Default value: nil + ValidationKeyGetter jwt.Keyfunc + // The name of the property in the request where the user information + // from the JWT will be stored. + // Default value: "user" + UserProperty string + // The function that will be called when there's an error validating the token + // Default value: + CredentialsOptional bool + // A function that extracts the token from the request + // Default: FromAuthHeader (i.e., from Authorization header as bearer token) + Debug bool + // When set, all requests with the OPTIONS method will use authentication + // Default: false + EnableAuthOnOptions bool + // When set, the middelware verifies that tokens are signed with the specific signing algorithm + // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks + // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ + // Default: nil + SigningMethod jwt.SigningMethod +} + // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKey `json:"keys"` @@ -32,14 +58,17 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header -func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { +type JWTValidator struct { + options Options +} + +func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { return nil, err } - return New(Options{ + options := Options{ ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { // Verify 'aud' claim checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) @@ -62,7 +91,58 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT }, SigningMethod: jwt.SigningMethodRS256, EnableAuthOnOptions: false, - }), nil + } + + if options.UserProperty == "" { + options.UserProperty = "user" + } + + return &JWTValidator{ + options: options, + }, nil +} + +func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { + // If the token is empty... + if token == "" { + // Check if it was required + if m.options.CredentialsOptional { + log.Debugf("no credentials found (CredentialsOptional=true)") + // No error, just no token (and that is ok given that CredentialsOptional is true) + return nil, nil + } + + // If we get here, the required token is missing + errorMsg := "required authorization token not found" + log.Debugf(" Error: No credentials found (CredentialsOptional=false)") + return nil, fmt.Errorf(errorMsg) + } + + // Now parse the token + parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter) + + // Check if there was an error in parsing... + if err != nil { + log.Debugf("error parsing token: %v", err) + return nil, fmt.Errorf("Error parsing token: %w", err) + } + + if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] { + errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", + m.options.SigningMethod.Alg(), + parsedToken.Header["alg"]) + log.Debugf("error validating token algorithm: %s", errorMsg) + return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) + } + + // Check if the parsed token is valid... + if !parsedToken.Valid { + errorMsg := "token is invalid" + log.Debugf(errorMsg) + return nil, errors.New(errorMsg) + } + + return parsedToken, nil } func getPemKeys(keysLocation string) (*Jwks, error) { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 71870fd84..c91201344 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -48,6 +48,7 @@ type MockAccountManager struct { ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(pat string) error UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) @@ -186,6 +187,14 @@ func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *s return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } +// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface +func (am *MockAccountManager) MarkPATUsed(pat string) error { + if am.MarkPATUsedFunc != nil { + return am.MarkPATUsedFunc(pat) + } + return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") +} + // AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { if am.AddPATToUserFunc != nil { From 5c1acdbf2fdd380c29e84e4810b33bedb427413b Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 13:58:44 +0200 Subject: [PATCH 21/50] move validation into account manager + func for get requests --- management/server/account.go | 6 +- management/server/http/pat_handler.go | 70 +++----- management/server/http/pat_handler_test.go | 40 ++++- management/server/mock_server/account_mock.go | 36 +++- management/server/user.go | 116 +++++++++++-- management/server/user_test.go | 160 +++++++++++++++--- 6 files changed, 322 insertions(+), 106 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 3db6a2fe0..ce2d2fc1e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -67,8 +67,10 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) - AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error - DeletePAT(accountID string, userID string, tokenID string) error + CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(accountID string, executingUserID string, targetUserId string, tokenID string) error + GetPAT(accountID string, executingUserID string, targetUserId string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(accountID string, executingUserID string, targetUserId string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index d3e8b9ac5..d2398a7e1 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -46,17 +46,19 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - if userID != user.Id { - util.WriteErrorResponse("User not authorized to get tokens", http.StatusUnauthorized, w) + + pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID) + if err != nil { + util.WriteError(err, w) return } - var pats []*api.PersonalAccessToken - for _, pat := range account.Users[userID].PATs { - pats = append(pats, toPATResponse(pat)) + var patResponse []*api.PersonalAccessToken + for _, pat := range pats { + patResponse = append(patResponse, toPATResponse(pat)) } - util.WriteJSONObject(w, pats) + util.WriteJSONObject(w, patResponse) } // GetToken is HTTP GET handler that returns a personal access token for the given user @@ -69,15 +71,11 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - if userID != user.Id { - util.WriteErrorResponse("User not authorized to get token", http.StatusUnauthorized, w) - return - } tokenID := vars["tokenId"] if len(tokenID) == 0 { @@ -85,15 +83,9 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - user = account.Users[userID] - if user == nil { - util.WriteError(status.Errorf(status.NotFound, "user not found"), w) - return - } - - pat := user.PATs[tokenID] - if pat == nil { - util.WriteError(status.Errorf(status.NotFound, "PAT not found"), w) + pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID) + if err != nil { + util.WriteError(err, w) return } @@ -110,15 +102,11 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - if userID != user.Id { - util.WriteErrorResponse("User not authorized to create token", http.StatusUnauthorized, w) - return - } var req api.PostApiUsersUserIdTokensJSONRequestBody err = json.NewDecoder(r.Body).Decode(&req) @@ -127,23 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - if req.Name == "" { - util.WriteErrorResponse("name can't be empty", http.StatusBadRequest, w) - return - } - - if req.ExpiresIn < 1 || req.ExpiresIn > 365 { - util.WriteErrorResponse("expiration has to be between 1 and 365", http.StatusBadRequest, w) - return - } - - pat, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - err = h.accountManager.AddPATToUser(account.Id, userID, &pat.PersonalAccessToken) + pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(err, w) return @@ -162,15 +134,11 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - if userID != user.Id { - util.WriteErrorResponse("User not authorized to delete token", http.StatusUnauthorized, w) - return - } tokenID := vars["tokenId"] if len(tokenID) == 0 { @@ -178,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(account.Id, userID, tokenID) + err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID) if err != nil { util.WriteError(err, w) return diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 7b83c5db8..de79f1006 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -63,31 +63,55 @@ var testAccount = &server.Account{ func initPATTestData() *PATHandler { return &PATHandler{ accountManager: &mock_server.MockAccountManager{ - AddPATToUserFunc: func(accountID string, userID string, pat *server.PersonalAccessToken) error { + CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { - return status.Errorf(status.NotFound, "account with ID %s not found", accountID) + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } - if userID != existingUserID { - return status.Errorf(status.NotFound, "user with ID %s not found", userID) + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return nil + return &server.PersonalAccessTokenGenerated{ + PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", + PersonalAccessToken: server.PersonalAccessToken{}, + }, nil }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testAccount, testAccount.Users[existingUserID], nil }, - DeletePATFunc: func(accountID string, userID string, tokenID string) error { + DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { return status.Errorf(status.NotFound, "account with ID %s not found", accountID) } - if userID != existingUserID { - return status.Errorf(status.NotFound, "user with ID %s not found", userID) + if targetUserID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } if tokenID != existingTokenID { return status.Errorf(status.NotFound, "token with ID %s not found", tokenID) } return nil }, + GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + if tokenID != existingTokenID { + return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return testAccount.Users[existingUserID].PATs[existingTokenID], nil + }, + GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 71870fd84..53cd2d672 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -60,8 +60,10 @@ type MockAccountManager struct { SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - AddPATToUserFunc func(accountID string, userID string, pat *server.PersonalAccessToken) error - DeletePATFunc func(accountID string, userID string, tokenID string) error + CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error + GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error @@ -186,22 +188,38 @@ func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *s return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } -// AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface -func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { - if am.AddPATToUserFunc != nil { - return am.AddPATToUserFunc(accountID, userID, pat) +// CreatePAT mock implementation of GetPAT from server.AccountManager interface +func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + if am.CreatePATFunc != nil { + return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn) } - return status.Errorf(codes.Unimplemented, "method AddPATToUser is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") } // DeletePAT mock implementation of DeletePAT from server.AccountManager interface -func (am *MockAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { +func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { if am.DeletePATFunc != nil { - return am.DeletePATFunc(accountID, userID, tokenID) + return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID) } return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") } +// GetPAT mock implementation of GetPAT from server.AccountManager interface +func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + if am.GetPATFunc != nil { + return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") +} + +// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface +func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + if am.GetAllPATsFunc != nil { + return am.GetAllPATsFunc(accountID, executingUserID, targetUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") +} + // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { if am.GetNetworkMapFunc != nil { diff --git a/management/server/user.go b/management/server/user.go index 572843fff..97e9cd0c7 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -193,37 +193,63 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us } -// AddPATToUser takes the userID and the accountID the user belongs to and assigns a provided PersonalAccessToken to that user -func (am *DefaultAccountManager) AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error { +// CreatePAT creates a new PAT for the given user +func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() + if tokenName == "" { + return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") + } + + if expiresIn < 1 || expiresIn > 365 { + return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") + } + + if executingUserID != targetUserId { + return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") + } + account, err := am.Store.GetAccount(accountID) if err != nil { - return err + return nil, err } - user := account.Users[userID] - if user == nil { - return status.Errorf(status.NotFound, "user not found") + targetUser := account.Users[targetUserId] + if targetUser == nil { + return nil, status.Errorf(status.NotFound, "targetUser not found") } - user.PATs[pat.ID] = pat + pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) + } - return am.Store.SaveAccount(account) + targetUser.PATs[pat.ID] = &pat.PersonalAccessToken + + err = am.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to save account: %v", err) + } + + return pat, nil } // DeletePAT deletes a specific PAT from a user -func (am *DefaultAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { +func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) - if err != nil { - return err + if executingUserID != targetUserID { + return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") } - user := account.Users[userID] + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] if user == nil { return status.Errorf(status.NotFound, "user not found") } @@ -235,15 +261,73 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, userID string, toke err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) if err != nil { - return err + return status.Errorf(status.Internal, "Failed to delete token id index: %s", err) } err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) if err != nil { - return err + return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) } delete(user.PATs, tokenID) - return am.Store.SaveAccount(account) + err = am.Store.SaveAccount(account) + if err != nil { + return status.Errorf(status.Internal, "Failed to save account: %s", err) + } + return nil +} + +// GetPAT returns a specific PAT from a user +func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if executingUserID != targetUserID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + + pat := user.PATs[tokenID] + if pat == nil { + return nil, status.Errorf(status.NotFound, "PAT not found") + } + + return pat, nil +} + +// GetAllPATs returns all PATs for a user +func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if executingUserID != targetUserID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + + var pats []*PersonalAccessToken + for _, pat := range user.PATs { + pats = append(pats, pat) + } + + return pats, nil } // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. diff --git a/management/server/user_test.go b/management/server/user_test.go index ffbb71282..1dd12e57b 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -7,13 +7,20 @@ import ( ) const ( - mockAccountID = "accountID" - mockUserID = "userID" - mockTokenID = "tokenID" - mockToken = "SoMeHaShEdToKeN" + mockAccountID = "accountID" + mockUserID = "userID" + mockTargetUserId = "targetUserID" + mockTokenID1 = "tokenID1" + mockToken1 = "SoMeHaShEdToKeN1" + mockTokenID2 = "tokenID2" + mockToken2 = "SoMeHaShEdToKeN2" + mockTokenName = "tokenName" + mockEmptyTokenName = "" + mockExpiresIn = 7 + mockWrongExpiresIn = 4506 ) -func TestUser_AddPATToUser(t *testing.T) { +func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -26,24 +33,19 @@ func TestUser_AddPATToUser(t *testing.T) { Store: store, } - pat := PersonalAccessToken{ - ID: mockTokenID, - HashedToken: mockToken, - } - - err = am.AddPATToUser(mockAccountID, mockUserID, &pat) + pat, err := am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } fileStore := am.Store.(*FileStore) - tokenID := fileStore.HashedPAT2TokenID[mockToken] + tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] if tokenID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") } - assert.Equal(t, mockTokenID, tokenID) + assert.Equal(t, pat.ID, tokenID) userID := fileStore.TokenID2UserID[tokenID] if userID == "" { @@ -52,15 +54,66 @@ func TestUser_AddPATToUser(t *testing.T) { assert.Equal(t, mockUserID, userID) } +func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + assert.Errorf(t, err, "Creating PAT for different user should thorw error") +} + +func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) + assert.Errorf(t, err, "Wrong expiration should thorw error") +} + +func TestUser_CreatePAT_WithEmptyName(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) + assert.Errorf(t, err, "Wrong expiration should thorw error") +} + func TestUser_DeletePAT(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ - mockTokenID: { - ID: mockTokenID, - HashedToken: mockToken, + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, }, }, } @@ -73,12 +126,79 @@ func TestUser_DeletePAT(t *testing.T) { Store: store, } - err = am.DeletePAT(mockAccountID, mockUserID, mockTokenID) + err = am.DeletePAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID]) - assert.Empty(t, store.HashedPAT2TokenID[mockToken]) - assert.Empty(t, store.TokenID2UserID[mockTokenID]) + assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1]) + assert.Empty(t, store.HashedPAT2TokenID[mockToken1]) + assert.Empty(t, store.TokenID2UserID[mockTokenID1]) +} + +func TestUser_GetPAT(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + pat, err := am.GetPAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Equal(t, mockTokenID1, pat.ID) + assert.Equal(t, mockToken1, pat.HashedToken) +} + +func TestUser_GetAllPATs(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, + }, + mockTokenID2: { + ID: mockTokenID2, + HashedToken: mockToken2, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + pats, err := am.GetAllPATs(mockAccountID, mockUserID, mockUserID) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Equal(t, 2, len(pats)) + assert.Equal(t, mockTokenID1, pats[0].ID) + assert.Equal(t, mockToken1, pats[0].HashedToken) + assert.Equal(t, mockTokenID2, pats[1].ID) + assert.Equal(t, mockToken2, pats[1].HashedToken) } From a7519859bccf2f4a02c672437ed3bd4196ace724 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 14:15:44 +0200 Subject: [PATCH 22/50] fix test --- management/server/user_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/management/server/user_test.go b/management/server/user_test.go index 1dd12e57b..238aa2bff 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -197,8 +197,4 @@ func TestUser_GetAllPATs(t *testing.T) { } assert.Equal(t, 2, len(pats)) - assert.Equal(t, mockTokenID1, pats[0].ID) - assert.Equal(t, mockToken1, pats[0].HashedToken) - assert.Equal(t, mockTokenID2, pats[1].ID) - assert.Equal(t, mockToken2, pats[1].HashedToken) } From 5e2f66d59142750c2ce9491af4b5d2dd607e12c4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 15:23:24 +0200 Subject: [PATCH 23/50] fix codacy --- management/server/account.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index ce2d2fc1e..2c146b3ad 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -67,10 +67,10 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) - CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) - DeletePAT(accountID string, executingUserID string, targetUserId string, tokenID string) error - GetPAT(accountID string, executingUserID string, targetUserId string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(accountID string, executingUserID string, targetUserId string) ([]*PersonalAccessToken, error) + CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error + GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) From 6c8bb6063278a21b78915aa1e0ee9c3407861b91 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:06:46 +0200 Subject: [PATCH 24/50] fix merge --- management/server/http/middleware/jwt.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 management/server/http/middleware/jwt.go diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go deleted file mode 100644 index e69de29bb..000000000 From e869882da11bb9e77627219264ad323f1c7f43f6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:14:51 +0200 Subject: [PATCH 25/50] fix merge --- management/server/grpcserver.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 958b6729e..0c8dad246 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -46,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, err } - var jwtMiddleware *middleware.JWTMiddleware + var jwtValidator *jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { - jwtMiddleware, err = middleware.NewJwtMiddleware( + jwtValidator, err = jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation) @@ -87,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager accountManager: accountManager, config: config, turnCredentialsManager: turnCredentialsManager, - jwtMiddleware: jwtMiddleware, + jwtValidator: jwtValidator, jwtClaimsExtractor: jwtClaimsExtractor, appMetrics: appMetrics, }, nil @@ -188,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { } func (s *GRPCServer) validateToken(jwtToken string) (string, error) { - if s.jwtMiddleware == nil { - return "", status.Error(codes.Internal, "no jwt middleware set") + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } From 2a799957064de411d4d962d94be418071ea601d1 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:22:15 +0200 Subject: [PATCH 26/50] fix linter --- management/server/account.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index ae61d34ec..be51e745d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1128,7 +1128,6 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { defer unlock() user, err := am.Store.GetUserByTokenID(tokenID) - log.Debugf("User: %v", user) if err != nil { return err } @@ -1145,9 +1144,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { pat.LastUsed = time.Now() - am.Store.SaveAccount(account) - - return nil + return am.Store.SaveAccount(account) } // GetAccountFromPAT returns Account and User associated with a personal access token From 1343a3f00ebbdac6e4eb91fa66076986400daf40 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 16:43:39 +0200 Subject: [PATCH 27/50] add test + codacy --- management/server/account_test.go | 38 +++++++++++++++++++ .../server/http/middleware/auth_middleware.go | 8 ++-- .../http/middleware/auth_midleware_test.go | 1 - management/server/jwtclaims/extractor.go | 20 +++++----- management/server/jwtclaims/extractor_test.go | 12 +++--- 5 files changed, 59 insertions(+), 20 deletions(-) delete mode 100644 management/server/http/middleware/auth_midleware_test.go diff --git a/management/server/account_test.go b/management/server/account_test.go index 25d501c8b..f21c93f0e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) } +func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { + store := newStore(t) + account := newAccountWithId("account_id", "testuser", "") + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + account.Users["someUser"] = &User{ + Id: "someUser", + PATs: map[string]*PersonalAccessToken{ + "tokenId": { + ID: "tokenId", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + err = am.MarkPATUsed("tokenId") + if err != nil { + t.Fatalf("Error when marking PAT used: %s", err) + } + + account, err = am.Store.GetAccount("account_id") + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) +} + func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0b1478ef3..54889466e 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -124,10 +124,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[jwtclaims.UserIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[string(jwtclaims.UserIDClaim)] = user.Id + claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id + claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain + claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) // Update the current request with the new context information. diff --git a/management/server/http/middleware/auth_midleware_test.go b/management/server/http/middleware/auth_midleware_test.go deleted file mode 100644 index c870d7c16..000000000 --- a/management/server/http/middleware/auth_midleware_test.go +++ /dev/null @@ -1 +0,0 @@ -package middleware diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 9d60da335..5063d7b91 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -6,12 +6,14 @@ import ( "github.com/golang-jwt/jwt" ) +type key string + const ( - TokenUserProperty = "user" - AccountIDSuffix = "wt_account_id" - DomainIDSuffix = "wt_account_domain" - DomainCategorySuffix = "wt_account_domain_category" - UserIDClaim = "sub" + TokenUserProperty key = "user" + AccountIDSuffix key = "wt_account_id" + DomainIDSuffix key = "wt_account_domain" + DomainCategorySuffix key = "wt_account_domain_category" + UserIDClaim key = "sub" ) // Extract function type @@ -60,7 +62,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { ce.FromRequestContext = ce.fromRequestContext } if ce.userIDClaim == "" { - ce.userIDClaim = UserIDClaim + ce.userIDClaim = string(UserIDClaim) } return ce } @@ -74,15 +76,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { return jwtClaims } jwtClaims.UserId = userID - accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] + accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)] if ok { jwtClaims.AccountId = accountIDClaim.(string) } - domainClaim, ok := claims[c.authAudience+DomainIDSuffix] + domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)] if ok { jwtClaims.Domain = domainClaim.(string) } - domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] + domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)] if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index d8acd79b6..d8476e039 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -12,21 +12,21 @@ import ( func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { claimMaps := jwt.MapClaims{} if claims.UserId != "" { - claimMaps[UserIDClaim] = claims.UserId + claimMaps[string(UserIDClaim)] = claims.UserId } if claims.AccountId != "" { - claimMaps[audiance+AccountIDSuffix] = claims.AccountId + claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId } if claims.Domain != "" { - claimMaps[audiance+DomainIDSuffix] = claims.Domain + claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain } if claims.DomainCategory != "" { - claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory + claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) require.NoError(t, err, "creating testing request failed") - testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint + testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint return testRequest } @@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) { t.Error("audience should be empty") return } - if c.extractor.userIDClaim != UserIDClaim { + if c.extractor.userIDClaim != string(UserIDClaim) { t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) return } From 454240ca05f0e3ff106e4c596c79ff937aaca42c Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 17:32:44 +0200 Subject: [PATCH 28/50] comments for codacy --- management/server/account.go | 1 + .../server/http/middleware/auth_middleware.go | 17 ++++++++++++----- management/server/jwtclaims/extractor.go | 13 +++++++++---- management/server/jwtclaims/jwtValidator.go | 3 +++ 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index be51e745d..27bf5606e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1123,6 +1123,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +// MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { unlock := am.Store.AcquireGlobalLock() defer unlock() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 54889466e..d211a6ef2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -17,8 +17,13 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// GetAccountFromPATFunc function type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + +// ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) + +// MarkPATUsedFunc function type MarkPATUsedFunc func(token string) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens @@ -29,8 +34,10 @@ type AuthMiddleware struct { audience string } +type key string + const ( - userProperty = "user" + userProperty key = "user" ) // NewAuthMiddleware instance constructor @@ -44,13 +51,13 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse } // Handler method of the middleware which authenticates a user either by JWT claims or by PAT -func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { +func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.Split(r.Header.Get("Authorization"), " ") authType := auth[0] switch strings.ToLower(authType) { case "bearer": - err := a.CheckJWTFromRequest(w, r) + err := m.CheckJWTFromRequest(w, r) if err != nil { log.Debugf("Error when validating JWT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) @@ -58,7 +65,7 @@ func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { } h.ServeHTTP(w, r) case "token": - err := a.CheckPATFromRequest(w, r) + err := m.CheckPATFromRequest(w, r) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) @@ -93,7 +100,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), string(userProperty), validatedToken)) // nolint // Update the current request with the new context information. *r = *newRequest return nil diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 5063d7b91..b5f30b8d6 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -9,11 +9,16 @@ import ( type key string const ( - TokenUserProperty key = "user" - AccountIDSuffix key = "wt_account_id" - DomainIDSuffix key = "wt_account_domain" + // TokenUserProperty key for the user property in the request context + TokenUserProperty key = "user" + // AccountIDSuffix suffix for the account id claim + AccountIDSuffix key = "wt_account_id" + // DomainIDSuffix suffix for the domain id claim + DomainIDSuffix key = "wt_account_domain" + // DomainCategorySuffix suffix for the domain category claim DomainCategorySuffix key = "wt_account_domain_category" - UserIDClaim key = "sub" + // UserIDClaim claim for the user id + UserIDClaim key = "sub" ) // Extract function type diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index d324c5ab3..ee9513c57 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -58,10 +58,12 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } +// JWTValidator struct to handle token validation and parsing type JWTValidator struct { options Options } +// NewJWTValidator constructor func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { @@ -102,6 +104,7 @@ func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTV }, nil } +// ValidateAndParse validates the token and returns the parsed token func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { // If the token is empty... if token == "" { From e08af7fcdff28219419e566c90045f7aee3827dc Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 17:46:21 +0200 Subject: [PATCH 29/50] codacy --- management/server/http/middleware/auth_middleware.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index d211a6ef2..2901ccfbd 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -100,7 +100,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), string(userProperty), validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint // Update the current request with the new context information. *r = *newRequest return nil From f273fe9f519621140420e5bb9f8a200c5de899b3 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 18:54:55 +0200 Subject: [PATCH 30/50] revert codacy --- .../server/http/middleware/auth_middleware.go | 12 +++++------- management/server/jwtclaims/extractor.go | 18 ++++++++---------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 2901ccfbd..aeebf2593 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -34,10 +34,8 @@ type AuthMiddleware struct { audience string } -type key string - const ( - userProperty key = "user" + userProperty = "user" ) // NewAuthMiddleware instance constructor @@ -131,10 +129,10 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[string(jwtclaims.UserIDClaim)] = user.Id - claimMaps[m.audience+string(jwtclaims.AccountIDSuffix)] = account.Id - claimMaps[m.audience+string(jwtclaims.DomainIDSuffix)] = account.Domain - claimMaps[m.audience+string(jwtclaims.DomainCategorySuffix)] = account.DomainCategory + claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) // Update the current request with the new context information. diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index b5f30b8d6..3bd518d00 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -6,19 +6,17 @@ import ( "github.com/golang-jwt/jwt" ) -type key string - const ( // TokenUserProperty key for the user property in the request context - TokenUserProperty key = "user" + TokenUserProperty = "user" // AccountIDSuffix suffix for the account id claim - AccountIDSuffix key = "wt_account_id" + AccountIDSuffix = "wt_account_id" // DomainIDSuffix suffix for the domain id claim - DomainIDSuffix key = "wt_account_domain" + DomainIDSuffix = "wt_account_domain" // DomainCategorySuffix suffix for the domain category claim - DomainCategorySuffix key = "wt_account_domain_category" + DomainCategorySuffix = "wt_account_domain_category" // UserIDClaim claim for the user id - UserIDClaim key = "sub" + UserIDClaim = "sub" ) // Extract function type @@ -81,15 +79,15 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { return jwtClaims } jwtClaims.UserId = userID - accountIDClaim, ok := claims[c.authAudience+string(AccountIDSuffix)] + accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix] if ok { jwtClaims.AccountId = accountIDClaim.(string) } - domainClaim, ok := claims[c.authAudience+string(DomainIDSuffix)] + domainClaim, ok := claims[c.authAudience+DomainIDSuffix] if ok { jwtClaims.Domain = domainClaim.(string) } - domainCategoryClaim, ok := claims[c.authAudience+string(DomainCategorySuffix)] + domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix] if ok { jwtClaims.DomainCategory = domainCategoryClaim.(string) } From ce775d59aeb7f14bccbe9de45a7189d77cf628f6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 18:59:35 +0200 Subject: [PATCH 31/50] revert codacy --- management/server/jwtclaims/extractor.go | 4 ++-- management/server/jwtclaims/extractor_test.go | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 3bd518d00..9aa00a004 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -19,7 +19,7 @@ const ( UserIDClaim = "sub" ) -// Extract function type +// ExtractClaims Extract function type type ExtractClaims func(r *http.Request) AuthorizationClaims // ClaimsExtractor struct that holds the extract function @@ -65,7 +65,7 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { ce.FromRequestContext = ce.fromRequestContext } if ce.userIDClaim == "" { - ce.userIDClaim = string(UserIDClaim) + ce.userIDClaim = UserIDClaim } return ce } diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index d8476e039..53f8818b1 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -12,16 +12,16 @@ import ( func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { claimMaps := jwt.MapClaims{} if claims.UserId != "" { - claimMaps[string(UserIDClaim)] = claims.UserId + claimMaps[UserIDClaim] = claims.UserId } if claims.AccountId != "" { - claimMaps[audiance+string(AccountIDSuffix)] = claims.AccountId + claimMaps[audiance+AccountIDSuffix] = claims.AccountId } if claims.Domain != "" { - claimMaps[audiance+string(DomainIDSuffix)] = claims.Domain + claimMaps[audiance+DomainIDSuffix] = claims.Domain } if claims.DomainCategory != "" { - claimMaps[audiance+string(DomainCategorySuffix)] = claims.DomainCategory + claimMaps[audiance+DomainCategorySuffix] = claims.DomainCategory } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) @@ -124,7 +124,7 @@ func TestExtractClaimsSetOptions(t *testing.T) { t.Error("audience should be empty") return } - if c.extractor.userIDClaim != string(UserIDClaim) { + if c.extractor.userIDClaim != UserIDClaim { t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim) return } From ca1dc5ac885bfc9b4804172b560335f8d58edb65 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Thu, 30 Mar 2023 19:03:44 +0200 Subject: [PATCH 32/50] disable access control for token endpoint --- .../server/http/middleware/access_control.go | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 5e56f75ab..f1ab898a8 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -2,6 +2,9 @@ package middleware import ( "net/http" + "regexp" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -34,12 +37,23 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := a.claimsExtract.FromRequestContext(r) - ok, err := a.isUserAdmin(claims) + ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) + if err != nil { + log.Debugf("Regex failed") + util.WriteError(status.Errorf(status.Internal, ""), w) + return + } + if ok { + log.Debugf("Valid Path") + h.ServeHTTP(w, r) + return + } + + ok, err = a.isUserAdmin(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return } - if !ok { switch r.Method { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: From 32c96c15b85d6675c22695fb151febf725b99dcf Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 10:30:05 +0200 Subject: [PATCH 33/50] disable linter errors by comment --- management/server/http/middleware/auth_middleware.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index aeebf2593..c3f9361dd 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -98,7 +98,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint // Update the current request with the new context information. *r = *newRequest return nil @@ -134,7 +134,7 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) - newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) + newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. *r = *newRequest return nil From 110067c00fc4c2a81dce37911ec081d4fa54dcf8 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:03:53 +0200 Subject: [PATCH 34/50] change order for access control checks and aquire account lock after global lock --- management/server/account.go | 10 ++++++- .../server/http/middleware/access_control.go | 27 ++++++++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 27bf5606e..78c9237b8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1126,7 +1126,6 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { unlock := am.Store.AcquireGlobalLock() - defer unlock() user, err := am.Store.GetUserByTokenID(tokenID) if err != nil { @@ -1138,6 +1137,15 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { return err } + unlock() + unlock = am.Store.AcquireAccountLock(account.Id) + defer unlock() + + account, err = am.Store.GetAccountByUser(user.Id) + if err != nil { + return err + } + pat, ok := account.Users[user.Id].PATs[tokenID] if !ok { return fmt.Errorf("token not found") diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index f1ab898a8..5f8389dfa 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -37,19 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := a.claimsExtract.FromRequestContext(r) - ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) - if err != nil { - log.Debugf("Regex failed") - util.WriteError(status.Errorf(status.Internal, ""), w) - return - } - if ok { - log.Debugf("Valid Path") - h.ServeHTTP(w, r) - return - } - - ok, err = a.isUserAdmin(claims) + ok, err := a.isUserAdmin(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return @@ -57,6 +45,19 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { if !ok { switch r.Method { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: + + ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) + if err != nil { + log.Debugf("Regex failed") + util.WriteError(status.Errorf(status.Internal, ""), w) + return + } + if ok { + log.Debugf("Valid Path") + h.ServeHTTP(w, r) + return + } + util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) return } From 2eaf4aa8d7b8972b95772374551466445d43d49c Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:44:22 +0200 Subject: [PATCH 35/50] add test for auth middleware --- .../http/middleware/auth_middleware_test.go | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 management/server/http/middleware/auth_middleware_test.go diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go new file mode 100644 index 000000000..e77beb21f --- /dev/null +++ b/management/server/http/middleware/auth_middleware_test.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/management/server" +) + +const ( + audience = "audience" + accountID = "accountID" + domain = "domain" + userID = "userID" + tokenID = "tokenID" + PAT = "PAT" + JWT = "JWT" + wrongToken = "wrongToken" +) + +var testAccount = &server.Account{ + Id: accountID, + Domain: domain, + Users: map[string]*server.User{ + userID: { + Id: userID, + PATs: map[string]*server.PersonalAccessToken{ + tokenID: { + ID: tokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: userID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + +func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { + if token == PAT { + return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + } + return nil, nil, nil, fmt.Errorf("PAT invalid") +} + +func mockValidateAndParseToken(token string) (*jwt.Token, error) { + if token == JWT { + return &jwt.Token{}, nil + } + return nil, fmt.Errorf("JWT invalid") +} + +func mockMarkPATUsed(token string) error { + if token == tokenID { + return nil + } + return fmt.Errorf("Should never get reached") +} + +func TestAccounts_AccountsHandler(t *testing.T) { + tt := []struct { + name string + authHeader string + expectedStatusCode int + }{ + { + name: "Valid PAT Token", + authHeader: "Token " + PAT, + expectedStatusCode: 200, + }, + { + name: "Invalid PAT Token", + authHeader: "Token " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Valid JWT Token", + authHeader: "Bearer " + JWT, + expectedStatusCode: 200, + }, + { + name: "Invalid JWT Token", + authHeader: "Bearer " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Basic Auth", + authHeader: "Basic " + PAT, + expectedStatusCode: 401, + }, + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // do nothing + }) + + authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience) + + handlerToTest := authMiddleware.Handler(nextHandler) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("Authorization", tc.authHeader) + rec := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rec, req) + + if rec.Result().StatusCode != tc.expectedStatusCode { + t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode) + } + }) + } + +} From 931c20c8fe6a1c2c7aca4c80a58410e6bd3bdb85 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:45:10 +0200 Subject: [PATCH 36/50] fix test name --- management/server/http/middleware/auth_middleware_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index e77beb21f..5a5558fa5 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -65,7 +65,7 @@ func mockMarkPATUsed(token string) error { return fmt.Errorf("Should never get reached") } -func TestAccounts_AccountsHandler(t *testing.T) { +func TestAuthMiddleware_Handler(t *testing.T) { tt := []struct { name string authHeader string From b2da0ae70facdb97fabfd394759dcbade67b2778 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 17:41:22 +0200 Subject: [PATCH 37/50] add activity events on PAT creation and deletion --- management/server/activity/codes.go | 16 ++++++++++++++++ management/server/user.go | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index a4a46439d..dacac4129 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -83,6 +83,10 @@ const ( AccountPeerLoginExpirationDisabled // AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account AccountPeerLoginExpirationDurationUpdated + // PersonalAccessTokenCreated indicates that a user created a personal access token + PersonalAccessTokenCreated + // PersonalAccessTokenDeleted indicates that a user deleted a personal access token + PersonalAccessTokenDeleted ) const ( @@ -168,6 +172,10 @@ const ( AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account" // AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated" + // PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity + PersonalAccessTokenCreatedMessage string = "Personal access token created" + // PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity + PersonalAccessTokenDeletedMessage string = "Personal access token deleted" ) // Activity that triggered an Event @@ -258,6 +266,10 @@ func (a Activity) Message() string { return AccountPeerLoginExpirationDisabledMessage case AccountPeerLoginExpirationDurationUpdated: return AccountPeerLoginExpirationDurationUpdatedMessage + case PersonalAccessTokenCreated: + return PersonalAccessTokenCreatedMessage + case PersonalAccessTokenDeleted: + return PersonalAccessTokenDeletedMessage default: return "UNKNOWN_ACTIVITY" } @@ -348,6 +360,10 @@ func (a Activity) StringCode() string { return "account.setting.peer.login.expiration.enable" case AccountPeerLoginExpirationDisabled: return "account.setting.peer.login.expiration.disable" + case PersonalAccessTokenCreated: + return "personal.access.token.create" + case PersonalAccessTokenDeleted: + return "personal.access.token.delete" default: return "UNKNOWN_ACTIVITY" } diff --git a/management/server/user.go b/management/server/user.go index 97e9cd0c7..692a2833a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -232,6 +232,9 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str return nil, status.Errorf(status.Internal, "failed to save account: %v", err) } + meta := map[string]any{"name": pat.Name} + am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta) + return pat, nil } @@ -267,6 +270,10 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str if err != nil { return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) } + + meta := map[string]any{"name": pat.Name} + am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) + delete(user.PATs, tokenID) err = am.Store.SaveAccount(account) From d3de03596113ea62df42b3a19c04dbc856692965 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Sat, 1 Apr 2023 11:04:21 +0200 Subject: [PATCH 38/50] error responses always lower case + duplicate error response fix --- management/server/http/middleware/auth_middleware.go | 8 +++----- management/server/http/util/util.go | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index c3f9361dd..a8c81012a 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -58,7 +58,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { err := m.CheckJWTFromRequest(w, r) if err != nil { log.Debugf("Error when validating JWT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) return } h.ServeHTTP(w, r) @@ -66,12 +66,12 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { err := m.CheckPATFromRequest(w, r) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) + util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) return } h.ServeHTTP(w, r) default: - util.WriteError(status.Errorf(status.Unauthorized, "No valid authentication provided"), w) + util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return } }) @@ -115,11 +115,9 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ account, user, pat, err := m.getAccountFromPAT(token) if err != nil { - util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) return fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.ExpirationDate) { - util.WriteError(status.Errorf(status.Unauthorized, "Token expired"), w) return fmt.Errorf("token expired") } diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index c40daa1a3..407443251 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" log "github.com/sirupsen/logrus" @@ -99,7 +100,7 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized default: } - msg = err.Error() + msg = strings.ToLower(err.Error()) } else { unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) log.Error(unhandledMSG) From 45badd2c39b84198ac9b0776dbf17a296dd88d84 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Sat, 1 Apr 2023 11:11:30 +0200 Subject: [PATCH 39/50] add event store to user tests --- management/server/user_test.go | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/management/server/user_test.go b/management/server/user_test.go index 238aa2bff..29e6bc2bc 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/activity" ) const ( @@ -30,7 +32,8 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } pat, err := am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) @@ -64,7 +67,8 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } _, err = am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) @@ -81,7 +85,8 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) @@ -98,7 +103,8 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) @@ -123,7 +129,8 @@ func TestUser_DeletePAT(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } err = am.DeletePAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) @@ -154,7 +161,8 @@ func TestUser_GetPAT(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } pat, err := am.GetPAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) @@ -188,7 +196,8 @@ func TestUser_GetAllPATs(t *testing.T) { } am := DefaultAccountManager{ - Store: store, + Store: store, + eventStore: &activity.InMemoryEventStore{}, } pats, err := am.GetAllPATs(mockAccountID, mockUserID, mockUserID) From 5dc0ff42a5fe01592b4bcdbd7aeb0eafb3713f25 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Sat, 1 Apr 2023 14:02:08 +0400 Subject: [PATCH 40/50] Fix broken auto-generated Rego rule (#769) Default Rego policy generated from the rules in some cases is broken. This change fixes the Rego template for rules to generate policies. Also, file store load constantly regenerates policy objects from rules. It allows updating/fixing of the default Rego template during releases. --- go.mod | 2 +- management/server/file_store.go | 26 +- management/server/policy.go | 64 ++-- management/server/policy_test.go | 281 +++++++++++++++--- management/server/rego/default_policy.rego | 10 +- .../server/rego/default_policy_module.rego | 12 +- 6 files changed, 315 insertions(+), 80 deletions(-) diff --git a/go.mod b/go.mod index c74fee409..5220eab6e 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.33.0 go.opentelemetry.io/otel/metric v0.33.0 go.opentelemetry.io/otel/sdk/metric v0.33.0 + golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf golang.org/x/net v0.8.0 golang.org/x/sync v0.1.0 golang.org/x/term v0.6.0 @@ -126,7 +127,6 @@ require ( go.opentelemetry.io/otel v1.11.1 // indirect go.opentelemetry.io/otel/sdk v1.11.1 // indirect go.opentelemetry.io/otel/trace v1.11.1 // indirect - golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/text v0.8.0 // indirect diff --git a/management/server/file_store.go b/management/server/file_store.go index 9b4a9f47f..f79179841 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -121,15 +121,25 @@ func restore(file string) (*FileStore, error) { store.PrivateDomain2AccountID[account.Domain] = accountID } - // if no policies are defined, that means we need to migrate Rules to policies - if len(account.Policies) == 0 { + // TODO: policy query generated from the Go template and rule object. + // We need to refactor this part to avoid using templating for policies queries building + // and drop this migration part. + policies := make(map[string]int, len(account.Policies)) + for i, policy := range account.Policies { + policies[policy.ID] = i + } + if account.Policies == nil { account.Policies = make([]*Policy, 0) - for _, rule := range account.Rules { - policy, err := RuleToPolicy(rule) - if err != nil { - log.Errorf("unable to migrate rule to policy: %v", err) - continue - } + } + for _, rule := range account.Rules { + policy, err := RuleToPolicy(rule) + if err != nil { + log.Errorf("unable to migrate rule to policy: %v", err) + continue + } + if i, ok := policies[policy.ID]; ok { + account.Policies[i] = policy + } else { account.Policies = append(account.Policies, policy) } } diff --git a/management/server/policy.go b/management/server/policy.go index 31f6bb655..8a166c25c 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -178,6 +178,9 @@ type FirewallRule struct { // Port of the traffic Port string + + // id for internal purposes + id string } // parseFromRegoResult parses the Rego result to a FirewallRule. @@ -218,39 +221,35 @@ func (f *FirewallRule) parseFromRegoResult(value interface{}) error { f.Action = action f.Port = port + // NOTE: update this id each time when new field added + f.id = peerID + peerIP + direction + action + port + return nil } -// getRegoQuery returns a initialized Rego object with default rule. -func (a *Account) getRegoQuery() (rego.PreparedEvalQuery, error) { - queries := []func(*rego.Rego){ - rego.Query("data.netbird.all"), - rego.Module("netbird", defaultPolicyModule), - } - for i, p := range a.Policies { - if !p.Enabled { - continue - } - queries = append(queries, rego.Module(fmt.Sprintf("netbird-%d", i), p.Query)) - } - return rego.New(queries...).PrepareForEval(context.TODO()) -} - -// getPeersByPolicy returns all peers that given peer has access to. -func (a *Account) getPeersByPolicy(peerID string) ([]*Peer, []*FirewallRule) { +// queryPeersAndFwRulesByRego returns a list associated Peers and firewall rules list for this peer. +func (a *Account) queryPeersAndFwRulesByRego( + peerID string, + queryNumber int, + query string, +) ([]*Peer, []*FirewallRule) { input := map[string]interface{}{ "peer_id": peerID, "peers": a.Peers, "groups": a.Groups, } - query, err := a.getRegoQuery() + stmt, err := rego.New( + rego.Query("data.netbird.all"), + rego.Module("netbird", defaultPolicyModule), + rego.Module(fmt.Sprintf("netbird-%d", queryNumber), query), + ).PrepareForEval(context.TODO()) if err != nil { log.WithError(err).Error("get Rego query") return nil, nil } - evalResult, err := query.Eval( + evalResult, err := stmt.Eval( context.TODO(), rego.EvalInput(input), ) @@ -318,6 +317,33 @@ func (a *Account) getPeersByPolicy(peerID string) ([]*Peer, []*FirewallRule) { return peers, rules } +// getPeersByPolicy returns all peers that given peer has access to. +func (a *Account) getPeersByPolicy(peerID string) (peers []*Peer, rules []*FirewallRule) { + peersSeen := make(map[string]struct{}) + ruleSeen := make(map[string]struct{}) + for i, policy := range a.Policies { + if !policy.Enabled { + continue + } + p, r := a.queryPeersAndFwRulesByRego(peerID, i, policy.Query) + for _, peer := range p { + if _, ok := peersSeen[peer.ID]; ok { + continue + } + peers = append(peers, peer) + peersSeen[peer.ID] = struct{}{} + } + for _, rule := range r { + if _, ok := ruleSeen[rule.id]; ok { + continue + } + rules = append(rules, rule) + ruleSeen[rule.id] = struct{}{} + } + } + return +} + // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) { unlock := am.Store.AcquireAccountLock(accountID) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 73663a8fd..39ac44843 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -5,63 +5,268 @@ import ( "testing" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) func TestAccount_getPeersByPolicy(t *testing.T) { account := &Account{ Peers: map[string]*Peer{ - "peer1": { - ID: "peer1", - IP: net.IPv4(10, 20, 0, 1), + "cfif97at2r9s73au3q00": { + ID: "cfif97at2r9s73au3q00", + IP: net.ParseIP("100.65.14.88"), }, - "peer2": { - ID: "peer2", - IP: net.IPv4(10, 20, 0, 2), + "cfif97at2r9s73au3q0g": { + ID: "cfif97at2r9s73au3q0g", + IP: net.ParseIP("100.65.80.39"), }, - "peer3": { - ID: "peer3", - IP: net.IPv4(10, 20, 0, 3), + "cfif97at2r9s73au3q10": { + ID: "cfif97at2r9s73au3q10", + IP: net.ParseIP("100.65.254.139"), + }, + "cfif97at2r9s73au3q20": { + ID: "cfif97at2r9s73au3q20", + IP: net.ParseIP("100.65.62.5"), + }, + "cfj4tiqt2r9s73dmeun0": { + ID: "cfj4tiqt2r9s73dmeun0", + IP: net.ParseIP("100.65.32.206"), + }, + "cg7h032t2r9s73cg5fk0": { + ID: "cg7h032t2r9s73cg5fk0", + IP: net.ParseIP("100.65.250.202"), + }, + "cgcnkj2t2r9s73cg5vv0": { + ID: "cgcnkj2t2r9s73cg5vv0", + IP: net.ParseIP("100.65.13.186"), + }, + "cgcol4qt2r9s73cg601g": { + ID: "cgcol4qt2r9s73cg601g", + IP: net.ParseIP("100.65.29.55"), }, }, Groups: map[string]*Group{ - "gid1": { - ID: "gid1", - Name: "all", - Peers: []string{"peer1", "peer2", "peer3"}, + "cet9e92t2r9s7383ns20": { + ID: "cet9e92t2r9s7383ns20", + Name: "All", + Peers: []string{ + "cfif97at2r9s73au3q0g", + "cfif97at2r9s73au3q00", + "cfif97at2r9s73au3q20", + "cfif97at2r9s73au3q10", + "cfj4tiqt2r9s73dmeun0", + "cg7h032t2r9s73cg5fk0", + "cgcnkj2t2r9s73cg5vv0", + "cgcol4qt2r9s73cg601g", + }, + }, + "cev90bat2r9s7383o150": { + ID: "cev90bat2r9s7383o150", + Name: "swarm", + Peers: []string{ + "cfif97at2r9s73au3q0g", + "cfif97at2r9s73au3q00", + "cfif97at2r9s73au3q20", + "cfj4tiqt2r9s73dmeun0", + "cgcnkj2t2r9s73cg5vv0", + "cgcol4qt2r9s73cg601g", + }, }, }, Rules: map[string]*Rule{ - "default": { - ID: "default", - Name: "default", - Description: "default", - Disabled: false, - Source: []string{"gid1"}, - Destination: []string{"gid1"}, + "cet9e92t2r9s7383ns2g": { + ID: "cet9e92t2r9s7383ns2g", + Name: "Default", + Description: "This is a default rule that allows connections between all the resources", + Source: []string{ + "cet9e92t2r9s7383ns20", + }, + Destination: []string{ + "cet9e92t2r9s7383ns20", + }, + }, + "cev90bat2r9s7383o15g": { + ID: "cev90bat2r9s7383o15g", + Name: "Swarm", + Description: "", + Source: []string{ + "cev90bat2r9s7383o150", + "cet9e92t2r9s7383ns20", + }, + Destination: []string{ + "cev90bat2r9s7383o150", + }, }, }, } - rule, err := RuleToPolicy(account.Rules["default"]) + rule1, err := RuleToPolicy(account.Rules["cet9e92t2r9s7383ns2g"]) assert.NoError(t, err) - account.Policies = append(account.Policies, rule) + rule2, err := RuleToPolicy(account.Rules["cev90bat2r9s7383o15g"]) + assert.NoError(t, err) - peers, firewallRules := account.getPeersByPolicy("peer1") - assert.Len(t, peers, 2) - assert.Contains(t, peers, account.Peers["peer2"]) - assert.Contains(t, peers, account.Peers["peer3"]) + account.Policies = append(account.Policies, rule1, rule2) - epectedFirewallRules := []*FirewallRule{ - {PeerID: "peer1", PeerIP: "10.20.0.1", Direction: "dst", Action: "accept", Port: ""}, - {PeerID: "peer2", PeerIP: "10.20.0.2", Direction: "dst", Action: "accept", Port: ""}, - {PeerID: "peer3", PeerIP: "10.20.0.3", Direction: "dst", Action: "accept", Port: ""}, - {PeerID: "peer1", PeerIP: "10.20.0.1", Direction: "src", Action: "accept", Port: ""}, - {PeerID: "peer2", PeerIP: "10.20.0.2", Direction: "src", Action: "accept", Port: ""}, - {PeerID: "peer3", PeerIP: "10.20.0.3", Direction: "src", Action: "accept", Port: ""}, - } - assert.Len(t, firewallRules, len(epectedFirewallRules)) - for i := range firewallRules { - assert.Equal(t, firewallRules[i], epectedFirewallRules[i]) - } + t.Run("check that all peers get map", func(t *testing.T) { + for _, p := range account.Peers { + peers, firewallRules := account.getPeersByPolicy(p.ID) + assert.GreaterOrEqual(t, len(peers), 2, "mininum number peers should present") + assert.GreaterOrEqual(t, len(firewallRules), 2, "mininum number of firewall rules should present") + } + }) + + t.Run("check first peer map details", func(t *testing.T) { + peers, firewallRules := account.getPeersByPolicy("cfif97at2r9s73au3q0g") + assert.Len(t, peers, 7) + assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q00"]) + assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q10"]) + assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q20"]) + assert.Contains(t, peers, account.Peers["cfj4tiqt2r9s73dmeun0"]) + assert.Contains(t, peers, account.Peers["cg7h032t2r9s73cg5fk0"]) + + epectedFirewallRules := []*FirewallRule{ + { + PeerID: "cfif97at2r9s73au3q00", + PeerIP: "100.65.14.88", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q00100.65.14.88srcaccept", + }, + { + PeerID: "cfif97at2r9s73au3q00", + PeerIP: "100.65.14.88", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q00100.65.14.88dstaccept", + }, + + { + PeerID: "cfif97at2r9s73au3q0g", + PeerIP: "100.65.80.39", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q0g100.65.80.39dstaccept", + }, + { + PeerID: "cfif97at2r9s73au3q0g", + PeerIP: "100.65.80.39", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q0g100.65.80.39srcaccept", + }, + + { + PeerID: "cfif97at2r9s73au3q10", + PeerIP: "100.65.254.139", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q10100.65.254.139dstaccept", + }, + { + PeerID: "cfif97at2r9s73au3q10", + PeerIP: "100.65.254.139", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q10100.65.254.139srcaccept", + }, + + { + PeerID: "cfif97at2r9s73au3q20", + PeerIP: "100.65.62.5", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q20100.65.62.5dstaccept", + }, + { + PeerID: "cfif97at2r9s73au3q20", + PeerIP: "100.65.62.5", + Direction: "src", + Action: "accept", + Port: "", + id: "cfif97at2r9s73au3q20100.65.62.5srcaccept", + }, + + { + PeerID: "cfj4tiqt2r9s73dmeun0", + PeerIP: "100.65.32.206", + Direction: "dst", + Action: "accept", + Port: "", + id: "cfj4tiqt2r9s73dmeun0100.65.32.206dstaccept", + }, + { + PeerID: "cfj4tiqt2r9s73dmeun0", + PeerIP: "100.65.32.206", + Direction: "src", + Action: "accept", + Port: "", + id: "cfj4tiqt2r9s73dmeun0100.65.32.206srcaccept", + }, + + { + PeerID: "cg7h032t2r9s73cg5fk0", + PeerIP: "100.65.250.202", + Direction: "dst", + Action: "accept", + Port: "", + id: "cg7h032t2r9s73cg5fk0100.65.250.202dstaccept", + }, + { + PeerID: "cg7h032t2r9s73cg5fk0", + PeerIP: "100.65.250.202", + Direction: "src", + Action: "accept", + Port: "", + id: "cg7h032t2r9s73cg5fk0100.65.250.202srcaccept", + }, + + { + PeerID: "cgcnkj2t2r9s73cg5vv0", + PeerIP: "100.65.13.186", + Direction: "dst", + Action: "accept", + Port: "", + id: "cgcnkj2t2r9s73cg5vv0100.65.13.186dstaccept", + }, + { + PeerID: "cgcnkj2t2r9s73cg5vv0", + PeerIP: "100.65.13.186", + Direction: "src", + Action: "accept", + Port: "", + id: "cgcnkj2t2r9s73cg5vv0100.65.13.186srcaccept", + }, + + { + PeerID: "cgcol4qt2r9s73cg601g", + PeerIP: "100.65.29.55", + Direction: "dst", + Action: "accept", + Port: "", + id: "cgcol4qt2r9s73cg601g100.65.29.55dstaccept", + }, + { + PeerID: "cgcol4qt2r9s73cg601g", + PeerIP: "100.65.29.55", + Direction: "src", + Action: "accept", + Port: "", + id: "cgcol4qt2r9s73cg601g100.65.29.55srcaccept", + }, + } + assert.Len(t, firewallRules, len(epectedFirewallRules)) + slices.SortFunc(firewallRules, func(a, b *FirewallRule) bool { + return a.PeerID < b.PeerID + }) + for i := range firewallRules { + assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + } + }) } diff --git a/management/server/rego/default_policy.rego b/management/server/rego/default_policy.rego index 92e975a02..a1012ae76 100644 --- a/management/server/rego/default_policy.rego +++ b/management/server/rego/default_policy.rego @@ -1,9 +1,9 @@ package netbird all[rule] { - is_peer_in_any_group([{{range $i, $e := .All}}{{if $i}},{{end}}"{{$e}}"{{end}}]) - rule := array.concat( - rules_from_groups([{{range $i, $e := .Destination}}{{if $i}},{{end}}"{{$e}}"{{end}}], "dst", "accept", ""), - rules_from_groups([{{range $i, $e := .Source}}{{if $i}},{{end}}"{{$e}}"{{end}}], "src", "accept", ""), - )[_] + is_peer_in_any_group([{{range $i, $e := .All}}{{if $i}},{{end}}"{{$e}}"{{end}}]) + rule := { + {{range $i, $e := .Destination}}rules_from_group("{{$e}}", "dst", "accept", ""),{{end}} + {{range $i, $e := .Source}}rules_from_group("{{$e}}", "src", "accept", ""),{{end}} + }[_][_] } diff --git a/management/server/rego/default_policy_module.rego b/management/server/rego/default_policy_module.rego index 846e22e21..7411db36a 100644 --- a/management/server/rego/default_policy_module.rego +++ b/management/server/rego/default_policy_module.rego @@ -17,17 +17,11 @@ get_rule(peer_id, direction, action, port) := rule if { } } -# peers_from_group returns a list of peer ids for a given group id -peers_from_group(group_id) := peers if { +# netbird_rules_from_group returns a list of netbird rules for a given group_id +rules_from_group(group_id, direction, action, port) := rules if { group := input.groups[_] group.ID == group_id - peers := [peer | peer := group.Peers[_]] -} - -# netbird_rules_from_groups returns a list of netbird rules for a given list of group names -rules_from_groups(groups, direction, action, port) := rules if { - group_id := groups[_] - rules := [get_rule(peer, direction, action, port) | peer := peers_from_group(group_id)[_]] + rules := [get_rule(peer, direction, action, port) | peer := group.Peers[_]] } # is_peer_in_any_group checks that input peer present at least in one group From 86f9051a3060088f82ffc26caa0798c1a2880a57 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 3 Apr 2023 16:59:13 +0200 Subject: [PATCH 41/50] Fix/connection listener (#777) Fix add/remove connection listener In case we call the RemoveConnListener from Java then we lose the reference from the original instance --- client/android/client.go | 10 ++-- client/internal/peer/notifier.go | 67 ++++++++++++++------------- client/internal/peer/notifier_test.go | 66 ++++++++++++++++++++++++++ client/internal/peer/status.go | 12 ++--- 4 files changed, 111 insertions(+), 44 deletions(-) diff --git a/client/android/client.go b/client/android/client.go index 5e3c0c85a..3a7c2c8dc 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -118,12 +118,12 @@ func (c *Client) PeersList() *PeerInfoArray { return &PeerInfoArray{items: peerInfos} } -// AddConnectionListener add new network connection listener -func (c *Client) AddConnectionListener(listener ConnectionListener) { - c.recorder.AddConnectionListener(listener) +// SetConnectionListener set the network connection listener +func (c *Client) SetConnectionListener(listener ConnectionListener) { + c.recorder.SetConnectionListener(listener) } // RemoveConnectionListener remove connection listener -func (c *Client) RemoveConnectionListener(listener ConnectionListener) { - c.recorder.RemoveConnectionListener(listener) +func (c *Client) RemoveConnectionListener() { + c.recorder.RemoveConnectionListener() } diff --git a/client/internal/peer/notifier.go b/client/internal/peer/notifier.go index 4e618d2f8..b2d324c6c 100644 --- a/client/internal/peer/notifier.go +++ b/client/internal/peer/notifier.go @@ -14,32 +14,31 @@ const ( type notifier struct { serverStateLock sync.Mutex listenersLock sync.Mutex - listeners map[Listener]struct{} + listener Listener currentServerState bool currentClientState bool lastNotification int } func newNotifier() *notifier { - return ¬ifier{ - listeners: make(map[Listener]struct{}), - } + return ¬ifier{} } -func (n *notifier) addListener(listener Listener) { +func (n *notifier) setListener(listener Listener) { n.listenersLock.Lock() defer n.listenersLock.Unlock() n.serverStateLock.Lock() - go n.notifyListener(listener, n.lastNotification) + n.notifyListener(listener, n.lastNotification) n.serverStateLock.Unlock() - n.listeners[listener] = struct{}{} + + n.listener = listener } -func (n *notifier) removeListener(listener Listener) { +func (n *notifier) removeListener() { n.listenersLock.Lock() defer n.listenersLock.Unlock() - delete(n.listeners, listener) + n.listener = nil } func (n *notifier) updateServerStates(mgmState bool, signalState bool) { @@ -64,7 +63,7 @@ func (n *notifier) updateServerStates(mgmState bool, signalState bool) { } n.lastNotification = n.calculateState(newState, n.currentClientState) - go n.notifyAll(n.lastNotification) + n.notify(n.lastNotification) } func (n *notifier) clientStart() { @@ -72,7 +71,7 @@ func (n *notifier) clientStart() { defer n.serverStateLock.Unlock() n.currentClientState = true n.lastNotification = n.calculateState(n.currentServerState, true) - go n.notifyAll(n.lastNotification) + n.notify(n.lastNotification) } func (n *notifier) clientStop() { @@ -80,7 +79,7 @@ func (n *notifier) clientStop() { defer n.serverStateLock.Unlock() n.currentClientState = false n.lastNotification = n.calculateState(n.currentServerState, false) - go n.notifyAll(n.lastNotification) + n.notify(n.lastNotification) } func (n *notifier) clientTearDown() { @@ -88,33 +87,35 @@ func (n *notifier) clientTearDown() { defer n.serverStateLock.Unlock() n.currentClientState = false n.lastNotification = stateDisconnecting - go n.notifyAll(n.lastNotification) + n.notify(n.lastNotification) } func (n *notifier) isServerStateChanged(newState bool) bool { return n.currentServerState != newState } -func (n *notifier) notifyAll(state int) { +func (n *notifier) notify(state int) { n.listenersLock.Lock() defer n.listenersLock.Unlock() - - for l := range n.listeners { - n.notifyListener(l, state) + if n.listener == nil { + return } + n.notifyListener(n.listener, state) } func (n *notifier) notifyListener(l Listener, state int) { - switch state { - case stateDisconnected: - l.OnDisconnected() - case stateConnected: - l.OnConnected() - case stateConnecting: - l.OnConnecting() - case stateDisconnecting: - l.OnDisconnecting() - } + go func() { + switch state { + case stateDisconnected: + l.OnDisconnected() + case stateConnected: + l.OnConnected() + case stateConnecting: + l.OnConnecting() + case stateDisconnecting: + l.OnDisconnecting() + } + }() } func (n *notifier) calculateState(serverState bool, clientState bool) int { @@ -132,17 +133,17 @@ func (n *notifier) calculateState(serverState bool, clientState bool) int { func (n *notifier) peerListChanged(numOfPeers int) { n.listenersLock.Lock() defer n.listenersLock.Unlock() - - for l := range n.listeners { - l.OnPeersListChanged(numOfPeers) + if n.listener == nil { + return } + n.listener.OnPeersListChanged(numOfPeers) } func (n *notifier) localAddressChanged(fqdn, address string) { n.listenersLock.Lock() defer n.listenersLock.Unlock() - - for l := range n.listeners { - l.OnAddressChanged(fqdn, address) + if n.listener == nil { + return } + n.listener.OnAddressChanged(fqdn, address) } diff --git a/client/internal/peer/notifier_test.go b/client/internal/peer/notifier_test.go index f21193e06..a9045ac34 100644 --- a/client/internal/peer/notifier_test.go +++ b/client/internal/peer/notifier_test.go @@ -1,9 +1,48 @@ package peer import ( + "sync" "testing" ) +type mocListener struct { + lastState int + wg sync.WaitGroup + peers int +} + +func (l *mocListener) OnConnected() { + l.lastState = stateConnected + l.wg.Done() +} +func (l *mocListener) OnDisconnected() { + l.lastState = stateDisconnected + l.wg.Done() +} +func (l *mocListener) OnConnecting() { + l.lastState = stateConnecting + l.wg.Done() +} +func (l *mocListener) OnDisconnecting() { + l.lastState = stateDisconnecting + l.wg.Done() +} + +func (l *mocListener) OnAddressChanged(host, addr string) { + +} +func (l *mocListener) OnPeersListChanged(size int) { + l.peers = size +} + +func (l *mocListener) setWaiter() { + l.wg.Add(1) +} + +func (l *mocListener) wait() { + l.wg.Wait() +} + func Test_notifier_serverState(t *testing.T) { type scenario struct { @@ -30,3 +69,30 @@ func Test_notifier_serverState(t *testing.T) { }) } } + +func Test_notifier_SetListener(t *testing.T) { + listener := &mocListener{} + listener.setWaiter() + + n := newNotifier() + n.lastNotification = stateConnecting + n.setListener(listener) + listener.wait() + if listener.lastState != n.lastNotification { + t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) + } +} + +func Test_notifier_RemoveListener(t *testing.T) { + listener := &mocListener{} + listener.setWaiter() + n := newNotifier() + n.lastNotification = stateConnecting + n.setListener(listener) + n.removeListener() + n.peerListChanged(1) + + if listener.peers != 0 { + t.Errorf("invalid state: %d", listener.peers) + } +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 62841d6fc..508131816 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -293,14 +293,14 @@ func (d *Status) ClientTeardown() { d.notifier.clientTearDown() } -// AddConnectionListener add a listener to the notifier -func (d *Status) AddConnectionListener(listener Listener) { - d.notifier.addListener(listener) +// SetConnectionListener set a listener to the notifier +func (d *Status) SetConnectionListener(listener Listener) { + d.notifier.setListener(listener) } -// RemoveConnectionListener remove a listener from the notifier -func (d *Status) RemoveConnectionListener(listener Listener) { - d.notifier.removeListener(listener) +// RemoveConnectionListener remove the listener from the notifier +func (d *Status) RemoveConnectionListener() { + d.notifier.removeListener() } func (d *Status) onConnectionChanged() { From 5993982cca8f94b25d81e9f39ac8437e8a4bbe69 Mon Sep 17 00:00:00 2001 From: Ruakij Date: Tue, 4 Apr 2023 00:21:40 +0200 Subject: [PATCH 42/50] Add disable letsencrypt (#747) Add NETBIRD_DISABLE_LETSENCRYPT support to explicit disable let's encrypt Organize the setup.env.example variables into sections Add traefik example --- infrastructure_files/base.setup.env | 14 ++- infrastructure_files/configure.sh | 26 +++++ infrastructure_files/docker-compose.yml.tmpl | 30 ++++-- .../docker-compose.yml.tmpl.traefik | 99 +++++++++++++++++++ infrastructure_files/management.json.tmpl | 4 +- infrastructure_files/setup.env.example | 18 +++- 6 files changed, 174 insertions(+), 17 deletions(-) create mode 100644 infrastructure_files/docker-compose.yml.tmpl.traefik diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 2d74e3a66..e62b02a08 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -7,14 +7,18 @@ NETBIRD_MGMT_API_PORT=33073 # Management API endpoint address, used by the Dashboard NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT # Management Certficate file path. These are generated by the Dashboard container -NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem" +NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem" # Management Certficate key file path. -NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem" +NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem" # By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted} -# Turn credentials +# Signal +NETBIRD_SIGNAL_PROTOCOL="http" +NETBIRD_SIGNAL_PORT=10000 + +# Turn credentials # User TURN_USER=self # Password. If empty, the configure.sh will generate one with openssl @@ -61,4 +65,6 @@ export SIGNAL_VOLUMESUFFIX export LETSENCRYPT_VOLUMESUFFIX export NETBIRD_DISABLE_ANONYMOUS_METRICS export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN -export NETBIRD_MGMT_DNS_DOMAIN \ No newline at end of file +export NETBIRD_MGMT_DNS_DOMAIN +export NETBIRD_SIGNAL_PROTOCOL +export NETBIRD_SIGNAL_PORT diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index ed6367171..501098a57 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -121,6 +121,32 @@ if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted" fi +# Check if letsencrypt was disabled +if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]] +then + export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443" + export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT" + + echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore" + echo " and a reverse-proxy with Https needs to be placed in front of netbird!" + echo "The following forwards have to be setup:" + echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80" + echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT" + echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT" + echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80" + echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script." + echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!" + echo "You are also free to remove any occurences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" + echo "" + + export NETBIRD_SIGNAL_PROTOCOL="https" + unset NETBIRD_LETSENCRYPT_DOMAIN + unset NETBIRD_MGMT_API_CERT_FILE + unset NETBIRD_MGMT_API_CERT_KEY_FILE +else + export NETBIRD_LETSENCRYPT_DOMAIN="$NETBIRD_DOMAIN" +fi + env | grep NETBIRD envsubst < docker-compose.yml.tmpl > docker-compose.yml diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 296201710..c8febdea7 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -8,20 +8,25 @@ services: - 80:80 - 443:443 environment: + # Endpoints + - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + # OIDC - AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY - USE_AUTH0=$NETBIRD_USE_AUTH0 - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES - - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - - NGINX_SSL_PORT=443 - - LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN - - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI + # SSL + - NGINX_SSL_PORT=443 + # Letsencrypt + - LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN + - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL volumes: - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ + # Signal signal: image: netbirdio/signal:latest @@ -32,7 +37,8 @@ services: - 10000:80 # # port and command for Let's Encrypt validation # - 443:443 - # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + # Management management: image: netbirdio/management:latest @@ -46,8 +52,15 @@ services: ports: - $NETBIRD_MGMT_API_PORT:443 #API port # # command for Let's Encrypt validation without dashboard container - # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] - command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"] + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--port", "443", + "--log-file", "console", + "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", + "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", + "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" + ] + # Coturn coturn: image: coturn/coturn @@ -60,6 +73,7 @@ services: network_mode: host command: - -c /etc/turnserver.conf + volumes: $MGMT_VOLUMENAME: $SIGNAL_VOLUMENAME: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik new file mode 100644 index 000000000..9c1e0fd03 --- /dev/null +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -0,0 +1,99 @@ +version: "3" +services: + #UI dashboard + dashboard: + image: wiretrustee/dashboard:latest + restart: unless-stopped + #ports: + # - 80:80 + # - 443:443 + environment: + # Endpoints + - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT + # OIDC + - AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE + - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID + - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY + - USE_AUTH0=$NETBIRD_USE_AUTH0 + - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES + - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI + - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI + # SSL + - NGINX_SSL_PORT=443 + # Letsencrypt + - LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN + - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL + volumes: + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ + labels: + - traefik.enable=true + - traefik.http.routers.netbird-dashboard.rule=Host(`$NETBIRD_DOMAIN`) + - traefik.http.services.netbird-dashboard.loadbalancer.server.port=80 + + # Signal + signal: + image: netbirdio/signal:latest + restart: unless-stopped + volumes: + - $SIGNAL_VOLUMENAME:/var/lib/netbird + #ports: + # - 10000:80 + # # port and command for Let's Encrypt validation + # - 443:443 + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + labels: + - traefik.enable=true + - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.services.netbird-signal.loadbalancer.server.port=80 + - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c + + # Management + management: + image: netbirdio/management:latest + restart: unless-stopped + depends_on: + - dashboard + volumes: + - $MGMT_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro + - ./management.json:/etc/netbird/management.json + #ports: + # - $NETBIRD_MGMT_API_PORT:443 #API port + # # command for Let's Encrypt validation without dashboard container + # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--port", "443", + "--log-file", "console", + "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", + "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", + "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" + ] + labels: + - traefik.enable=true + - traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`) + - traefik.http.routers.netbird-api.service=netbird-api + - traefik.http.services.netbird-api.loadbalancer.server.port=443 + + - traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`) + - traefik.http.routers.netbird-management.service=netbird-management + - traefik.http.services.netbird-management.loadbalancer.server.port=443 + - traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c + + # Coturn + coturn: + image: coturn/coturn + restart: unless-stopped + domainname: $NETBIRD_DOMAIN + volumes: + - ./turnserver.conf:/etc/turnserver.conf:ro + # - ./privkey.pem:/etc/coturn/private/privkey.pem:ro + # - ./cert.pem:/etc/coturn/certs/cert.pem:ro + network_mode: host + command: + - -c /etc/turnserver.conf + +volumes: + $MGMT_VOLUMENAME: + $SIGNAL_VOLUMENAME: + $LETSENCRYPT_VOLUMENAME: diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl index f3b08101c..cb02c8f24 100644 --- a/infrastructure_files/management.json.tmpl +++ b/infrastructure_files/management.json.tmpl @@ -21,8 +21,8 @@ "TimeBasedCredentials": false }, "Signal": { - "Proto": "http", - "URI": "$NETBIRD_DOMAIN:10000", + "Proto": "$NETBIRD_SIGNAL_PROTOCOL", + "URI": "$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT", "Username": "", "Password": null }, diff --git a/infrastructure_files/setup.env.example b/infrastructure_files/setup.env.example index 09f407225..9703d3e4c 100644 --- a/infrastructure_files/setup.env.example +++ b/infrastructure_files/setup.env.example @@ -2,7 +2,11 @@ ## # Dashboard domain. e.g. app.mydomain.com NETBIRD_DOMAIN="" -# OIDC configuration e.g., https://example.eu.auth0.com/.well-known/openid-configuration + +# ------------------------------------------- +# OIDC +# e.g., https://example.eu.auth0.com/.well-known/openid-configuration +# ------------------------------------------- NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="" NETBIRD_AUTH_AUDIENCE="" # e.g. netbird-client @@ -13,13 +17,21 @@ NETBIRD_AUTH_CLIENT_ID="" NETBIRD_USE_AUTH0="false" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID="" -# e.g. hello@mydomain.com -NETBIRD_LETSENCRYPT_EMAIL="" + # if your IDP provider doesn't support fragmented URIs, configure custom # redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain. # NETBIRD_AUTH_REDIRECT_URI="/peers" # NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers" +# ------------------------------------------- +# Letsencrypt +# ------------------------------------------- +# Disable letsencrypt +# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead +NETBIRD_DISABLE_LETSENCRYPT=false +# e.g. hello@mydomain.com +NETBIRD_LETSENCRYPT_EMAIL="" + # Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection NETBIRD_DISABLE_ANONYMOUS_METRICS=false # DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted From 18098e7a7daadacf69c2f390edb724e9e8ea30bd Mon Sep 17 00:00:00 2001 From: Bethuel Date: Tue, 4 Apr 2023 01:35:54 +0300 Subject: [PATCH 43/50] Add single line installer (#775) detect OS package manager If a supported package manager is not available, use binary installation Check if desktop environment is available Skip installing the UI client if SKIP_UI_APP is set to true added tests for Ubuntu and macOS tests --- .github/workflows/install-test-darwin.yml | 58 +++++ .github/workflows/install-test-linux.yml | 36 +++ release_files/install.sh | 287 ++++++++++++++++++++++ 3 files changed, 381 insertions(+) create mode 100644 .github/workflows/install-test-darwin.yml create mode 100644 .github/workflows/install-test-linux.yml create mode 100644 release_files/install.sh diff --git a/.github/workflows/install-test-darwin.yml b/.github/workflows/install-test-darwin.yml new file mode 100644 index 000000000..cdf0cae5a --- /dev/null +++ b/.github/workflows/install-test-darwin.yml @@ -0,0 +1,58 @@ +name: Test installation Darwin + +on: + push: + branches: + - main + pull_request: + paths: + - "release_files/install.sh" + +jobs: + install-cli-only: + runs-on: macos-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Rename brew package + if: ${{ matrix.check_bin_install }} + run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak + + - name: Run install script + run: | + sh ./release_files/install.sh + env: + SKIP_UI_APP: true + + - name: Run tests + run: | + if ! command -v netbird &> /dev/null; then + echo "Error: netbird is not installed" + exit 1 + fi + install-all: + runs-on: macos-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Rename brew package + if: ${{ matrix.check_bin_install }} + run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak + + - name: Run install script + run: | + sh ./release_files/install.sh + + - name: Run tests + run: | + if ! command -v netbird &> /dev/null; then + echo "Error: netbird is not installed" + exit 1 + fi + + if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then + echo "Error: NetBird UI is not installed" + exit 1 + fi diff --git a/.github/workflows/install-test-linux.yml b/.github/workflows/install-test-linux.yml new file mode 100644 index 000000000..d4246881c --- /dev/null +++ b/.github/workflows/install-test-linux.yml @@ -0,0 +1,36 @@ +name: Test installation Linux + +on: + push: + branches: + - main + pull_request: + paths: + - "release_files/install.sh" + +jobs: + install-cli-only: + runs-on: ubuntu-latest + strategy: + matrix: + check_bin_install: [true, false] + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Rename apt package + if: ${{ matrix.check_bin_install }} + run: | + sudo mv /usr/bin/apt /usr/bin/apt.bak + sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak + + - name: Run install script + run: | + sh ./release_files/install.sh + + - name: Run tests + run: | + if ! command -v netbird &> /dev/null; then + echo "Error: netbird is not installed" + exit 1 + fi diff --git a/release_files/install.sh b/release_files/install.sh new file mode 100644 index 000000000..bb052a310 --- /dev/null +++ b/release_files/install.sh @@ -0,0 +1,287 @@ +#!/bin/sh +# This code is based on the netbird-installer contribution by physk on GitHub. +# Source: https://github.com/physk/netbird-installer +set -e + +OWNER="netbirdio" +REPO="netbird" +CLI_APP="netbird" +UI_APP="netbird-ui" + +# Set default variable +OS_NAME="" +OS_TYPE="" +ARCH="$(uname -m)" +PACKAGE_MANAGER="" +INSTALL_DIR="" + +get_latest_release() { + curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/latest" \ + | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' +} + +download_release_binary() { + VERSION=$(get_latest_release) + BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" + BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" + + # for Darwin, download the signed Netbird-UI + if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then + BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}_signed.zip" + fi + + BINARY_NAME="$1_${BINARY_BASE_NAME}" + DOWNLOAD_URL="${BASE_URL}/${VERSION}/${BINARY_NAME}" + + echo "Installing $1 from $DOWNLOAD_URL" + cd /tmp && curl -LO "$DOWNLOAD_URL" + + if [ "$OS_TYPE" = "darwin" ] && [ "$1" = "$UI_APP" ]; then + INSTALL_DIR="/Applications/NetBird UI.app" + + # Unzip the app and move to INSTALL_DIR + unzip -q -o "$BINARY_NAME" + mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" + else + tar -xzvf "$BINARY_NAME" + sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR" + fi +} + +add_apt_repo() { + sudo apt-get update + sudo apt-get install ca-certificates gnupg -y + + curl -sSL https://pkgs.wiretrustee.com/debian/public.key \ + | gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg + + APT_REPO="deb [signed-by=/usr/share/keyrings/wiretrustee-archive-keyring.gpg] https://pkgs.wiretrustee.com/debian stable main" + echo "$APT_REPO" | sudo tee /etc/apt/sources.list.d/wiretrustee.list + + sudo apt-get update +} + +add_rpm_repo() { +cat <<-EOF | sudo tee /etc/yum.repos.d/netbird.repo +[Netbird] +name=Netbird +baseurl=https://pkgs.netbird.io/yum/ +enabled=1 +gpgcheck=0 +gpgkey=https://pkgs.netbird.io/yum/repodata/repomd.xml.key +repo_gpgcheck=1 +EOF +} + +add_aur_repo() { + INSTALL_PKGS="git base-devel go" + REMOVE_PKGS="" + + # Check if dependencies are installed + for PKG in $INSTALL_PKGS; do + if ! pacman -Q "$PKG" > /dev/null 2>&1; then + # Install missing package(s) + sudo pacman -S "$PKG" --noconfirm + + # Add installed package for clean up later + REMOVE_PKGS="$REMOVE_PKGS $PKG" + fi + done + + # Build package from AUR + cd /tmp && git clone https://aur.archlinux.org/netbird.git + cd netbird && makepkg -sri --noconfirm + + if ! $SKIP_UI_APP; then + cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git + cd netbird-ui && makepkg -sri --noconfirm + fi + + # Clean up the installed packages + sudo pacman -Rs "$REMOVE_PKGS" --noconfirm +} + +install_native_binaries() { + # Checks for supported architecture + case "$ARCH" in + x86_64|amd64) + ARCH="amd64" + ;; + i?86|x86) + ARCH="386" + ;; + aarch64|arm64) + ARCH="arm64" + ;; + *) + echo "Architecture ${ARCH} not supported" + exit 2 + ;; + esac + + # download and copy binaries to INSTALL_DIR + download_release_binary "$CLI_APP" + if ! $SKIP_UI_APP; then + download_release_binary "$UI_APP" + fi +} + +install_netbird() { + # Check if netbird CLI is installed + if [ -x "$(command -v netbird)" ]; then + if netbird status > /dev/null 2>&1; then + echo "Netbird service is running, please stop it before proceeding" + fi + + echo "Netbird seems to be installed already, please remove it before proceeding" + exit 1 + fi + + # Checks if SKIP_UI_APP env is set + if [ -z "$SKIP_UI_APP" ]; then + SKIP_UI_APP=false + else + if $SKIP_UI_APP; then + echo "SKIP_UI_APP has been set to true in the environment" + echo "Netbird UI installation will be omitted based on your preference" + fi + fi + + # Identify OS name and default package manager + if type uname >/dev/null 2>&1; then + case "$(uname)" in + Linux) + OS_NAME="$(. /etc/os-release && echo "$ID")" + OS_TYPE="linux" + INSTALL_DIR="/usr/bin" + + # Allow netbird UI installation for x64 arch only + if [ "$ARCH" != "amd64" ] && [ "$ARCH" != "arm64" ] \ + && [ "$ARCH" != "x86_64" ];then + SKIP_UI_APP=true + echo "Netbird UI installation will be omitted as $ARCH is not a compactible architecture" + fi + + # Allow netbird UI installation for linux running desktop enviroment + if [ -z "$XDG_CURRENT_DESKTOP" ];then + SKIP_UI_APP=true + echo "Netbird UI installation will be omitted as Linux does not run desktop environment" + fi + + # Check the availability of a compactible package manager + if [ -x "$(command -v apt)" ]; then + PACKAGE_MANAGER="apt" + echo "The installation will be performed using apt package manager" + fi + if [ -x "$(command -v yum)" ]; then + PACKAGE_MANAGER="yum" + echo "The installation will be performed using yum package manager" + fi + if [ -x "$(command -v dnf)" ]; then + PACKAGE_MANAGER="dnf" + echo "The installation will be performed using dnf package manager" + fi + if [ -x "$(command -v pacman)" ]; then + PACKAGE_MANAGER="pacman" + echo "The installation will be performed using pacman package manager" + fi + ;; + Darwin) + OS_NAME="macos" + OS_TYPE="darwin" + INSTALL_DIR="/usr/local/bin" + + # Check the availability of a compatible package manager + if [ -x "$(command -v brew)" ]; then + PACKAGE_MANAGER="brew" + echo "The installation will be performed using brew package manager" + fi + ;; + esac + fi + + # Run the installation, if a desktop environment is not detected + # only the CLI will be installed + case "$PACKAGE_MANAGER" in + apt) + add_apt_repo + sudo apt-get install netbird -y + + if ! $SKIP_UI_APP; then + sudo apt-get install netbird-ui -y + fi + ;; + yum) + add_rpm_repo + sudo yum -y install netbird + if ! $SKIP_UI_APP; then + sudo yum -y install netbird-ui + fi + ;; + dnf) + add_rpm_repo + sudo dnf -y install dnf-plugin-config-manager + sudo dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + sudo dnf -y install netbird + + if ! $SKIP_UI_APP; then + sudo dnf -y install netbird-ui + fi + ;; + pacman) + sudo pacman -Syy + add_aur_repo + ;; + brew) + # Remove Wiretrustee if it had been installed using Homebrew before + if brew ls --versions wiretrustee >/dev/null 2>&1; then + echo "Removing existing wiretrustee client" + + # Stop and uninstall daemon service: + wiretrustee service stop + wiretrustee service uninstall + + # Unlik the app + brew unlink wiretrustee + fi + + brew install netbirdio/tap/netbird + if ! $SKIP_UI_APP; then + brew install --cask netbirdio/tap/netbird-ui + fi + ;; + *) + if [ "$OS_NAME" = "nixos" ];then + echo "Please add Netbird to your NixOS configuration.nix directly:" + echo + echo "services.netbird.enable = true;" + + if ! $SKIP_UI_APP; then + echo "environment.systemPackages = [ pkgs.netbird-ui ];" + fi + + echo "Build and apply new configuration:" + echo + echo "sudo nixos-rebuild switch" + exit 0 + fi + + install_native_binaries + ;; + esac + + # Load and start netbird service + if ! sudo netbird service install 2>&1; then + echo "Netbird service has already been loaded" + fi + if ! sudo netbird service start 2>&1; then + echo "Netbird service has already been started" + fi + + + echo "Installation has been finished. To connect, you need to run NetBird by executing the following command:" + echo "" + echo "sudo netbird up" +} + +install_netbird \ No newline at end of file From 109481e26dd18a1648925b5e8a097def7c7e8a94 Mon Sep 17 00:00:00 2001 From: Bethuel Date: Tue, 4 Apr 2023 15:26:17 +0300 Subject: [PATCH 44/50] Use first available package manager (#782) --- release_files/install.sh | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index bb052a310..fda7ea56e 100644 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -53,7 +53,7 @@ add_apt_repo() { sudo apt-get install ca-certificates gnupg -y curl -sSL https://pkgs.wiretrustee.com/debian/public.key \ - | gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg + | sudo gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg APT_REPO="deb [signed-by=/usr/share/keyrings/wiretrustee-archive-keyring.gpg] https://pkgs.wiretrustee.com/debian stable main" echo "$APT_REPO" | sudo tee /etc/apt/sources.list.d/wiretrustee.list @@ -172,16 +172,13 @@ install_netbird() { if [ -x "$(command -v apt)" ]; then PACKAGE_MANAGER="apt" echo "The installation will be performed using apt package manager" - fi - if [ -x "$(command -v yum)" ]; then - PACKAGE_MANAGER="yum" - echo "The installation will be performed using yum package manager" - fi - if [ -x "$(command -v dnf)" ]; then + elif [ -x "$(command -v dnf)" ]; then PACKAGE_MANAGER="dnf" echo "The installation will be performed using dnf package manager" - fi - if [ -x "$(command -v pacman)" ]; then + elif [ -x "$(command -v yum)" ]; then + PACKAGE_MANAGER="yum" + echo "The installation will be performed using yum package manager" + elif [ -x "$(command -v pacman)" ]; then PACKAGE_MANAGER="pacman" echo "The installation will be performed using pacman package manager" fi From f14f34cf2bc6041030372c817dc797207edc7006 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 4 Apr 2023 15:56:02 +0200 Subject: [PATCH 45/50] Add token source and device flow audience variables (#780) Supporting new dashboard option to configure a source token. Adding configuration support for setting a different audience for device authorization flow. fix custom id claim variable --- .github/workflows/test-docker-compose-linux.yml | 8 ++++++++ infrastructure_files/base.setup.env | 5 +++++ infrastructure_files/docker-compose.yml.tmpl | 3 ++- infrastructure_files/management.json.tmpl | 2 +- infrastructure_files/setup.env.example | 7 ++++++- infrastructure_files/tests/setup.env | 6 +++++- 6 files changed, 27 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test-docker-compose-linux.yml b/.github/workflows/test-docker-compose-linux.yml index d681dd89c..c28e94a4f 100644 --- a/.github/workflows/test-docker-compose-linux.yml +++ b/.github/workflows/test-docker-compose-linux.yml @@ -59,6 +59,10 @@ jobs: CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code CI_NETBIRD_AUTH_REDIRECT_URI: "/peers" + CI_NETBIRD_TOKEN_SOURCE: "idToken" + CI_NETBIRD_AUTH_USER_ID_CLAIM: "email" + CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super" + run: | grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY @@ -68,6 +72,10 @@ jobs: grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073" grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$' + grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$' + grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE + grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM + grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE - name: run docker compose up working-directory: infrastructure_files diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index e62b02a08..521c0d332 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -36,6 +36,8 @@ LETSENCRYPT_VOLUMESUFFIX="letsencrypt" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false} +NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=${NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE:-$NETBIRD_AUTH_AUDIENCE} +NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken} # exports export NETBIRD_DOMAIN @@ -68,3 +70,6 @@ export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN export NETBIRD_MGMT_DNS_DOMAIN export NETBIRD_SIGNAL_PROTOCOL export NETBIRD_SIGNAL_PORT +export NETBIRD_AUTH_USER_ID_CLAIM +export NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE +export NETBIRD_TOKEN_SOURCE \ No newline at end of file diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index c8febdea7..af7f1af00 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -19,6 +19,7 @@ services: - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI + - NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE # SSL - NGINX_SSL_PORT=443 # Letsencrypt @@ -60,7 +61,7 @@ services: "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" ] - + # Coturn coturn: image: coturn/coturn diff --git a/infrastructure_files/management.json.tmpl b/infrastructure_files/management.json.tmpl index cb02c8f24..19dcff898 100644 --- a/infrastructure_files/management.json.tmpl +++ b/infrastructure_files/management.json.tmpl @@ -43,7 +43,7 @@ "DeviceAuthorizationFlow": { "Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER", "ProviderConfig": { - "Audience": "$NETBIRD_AUTH_AUDIENCE", + "Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE", "Domain": "$NETBIRD_AUTH0_DOMAIN", "ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID", "TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT", diff --git a/infrastructure_files/setup.env.example b/infrastructure_files/setup.env.example index 9703d3e4c..324174757 100644 --- a/infrastructure_files/setup.env.example +++ b/infrastructure_files/setup.env.example @@ -17,11 +17,16 @@ NETBIRD_AUTH_CLIENT_ID="" NETBIRD_USE_AUTH0="false" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID="" +# Some IDPs requires different audience for device authorization flow, you can customize here +NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE # if your IDP provider doesn't support fragmented URIs, configure custom # redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain. # NETBIRD_AUTH_REDIRECT_URI="/peers" # NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers" +# Updates the preference to use id tokens instead of access token on dashboard +# Okta and Gitlab IDPs can benefit from this +# NETBIRD_TOKEN_SOURCE="idToken" # ------------------------------------------- # Letsencrypt @@ -35,4 +40,4 @@ NETBIRD_LETSENCRYPT_EMAIL="" # Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection NETBIRD_DISABLE_ANONYMOUS_METRICS=false # DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted -NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted +NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted \ No newline at end of file diff --git a/infrastructure_files/tests/setup.env b/infrastructure_files/tests/setup.env index cdb5e5c6b..09164a135 100644 --- a/infrastructure_files/tests/setup.env +++ b/infrastructure_files/tests/setup.env @@ -11,4 +11,8 @@ NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0 NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE # e.g. hello@mydomain.com NETBIRD_LETSENCRYPT_EMAIL="" -NETBIRD_AUTH_REDIRECT_URI="/peers" \ No newline at end of file +NETBIRD_AUTH_REDIRECT_URI="/peers" +NETBIRD_DISABLE_LETSENCRYPT=true +NETBIRD_TOKEN_SOURCE="idToken" +NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE="super" +NETBIRD_AUTH_USER_ID_CLAIM="email" \ No newline at end of file From fe1ea4a2d0fb8e7e7175d6623ee7b120428a2bc8 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 4 Apr 2023 16:40:56 +0200 Subject: [PATCH 46/50] Check multiple audience values (#781) Some IDP use different audience for different clients. This update checks HTTP and Device authorization flow audience values. --------- Co-authored-by: Givi Khojanashvili --- management/cmd/management.go | 2 +- management/server/config.go | 10 ++++++++++ management/server/grpcserver.go | 2 +- management/server/jwtclaims/jwtValidator.go | 10 ++++++++-- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/management/cmd/management.go b/management/cmd/management.go index 620a89f16..38535462f 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -184,7 +184,7 @@ var ( jwtValidator, err := jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, - config.HttpConfig.AuthAudience, + config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation, ) if err != nil { diff --git a/management/server/config.go b/management/server/config.go index 6a428c83b..f8d7d8db8 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -39,6 +39,16 @@ type Config struct { DeviceAuthorizationFlow *DeviceAuthorizationFlow } +// GetAuthAudiences returns the audience from the http config and device authorization flow config +func (c Config) GetAuthAudiences() []string { + audiences := []string{c.HttpConfig.AuthAudience} + + if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" { + audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience) + } + + return audiences +} // TURNConfig is a config of the TURNCredentialsManager type TURNConfig struct { TimeBasedCredentials bool diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 0c8dad246..e43c767c3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -51,7 +51,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( config.HttpConfig.AuthIssuer, - config.HttpConfig.AuthAudience, + config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation) if err != nil { return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index ee9513c57..147f8f2eb 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -64,7 +64,7 @@ type JWTValidator struct { } // NewJWTValidator constructor -func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) { +func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { return nil, err @@ -73,7 +73,13 @@ func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTV options := Options{ ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { // Verify 'aud' claim - checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + var checkAud bool + for _, audience := range audienceList { + checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + if checkAud { + break + } + } if !checkAud { return token, errors.New("invalid audience") } From 2be1a82f4a62516471606a3f96c5408080df2978 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 5 Apr 2023 11:39:22 +0200 Subject: [PATCH 47/50] Configurable port defaults from setup.env Allow configuring management and signal ports from setup.env Allow configuring Coturn range from setup.env --- infrastructure_files/base.setup.env | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 521c0d332..8fa58ffc3 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -3,7 +3,7 @@ # Management API # Management API port -NETBIRD_MGMT_API_PORT=33073 +NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073} # Management API endpoint address, used by the Dashboard NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT # Management Certficate file path. These are generated by the Dashboard container @@ -16,7 +16,7 @@ NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted} # Signal NETBIRD_SIGNAL_PROTOCOL="http" -NETBIRD_SIGNAL_PORT=10000 +NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000} # Turn credentials # User @@ -24,9 +24,9 @@ TURN_USER=self # Password. If empty, the configure.sh will generate one with openssl TURN_PASSWORD= # Min port -TURN_MIN_PORT=49152 +TURN_MIN_PORT=${TURN_MIN_PORT:-49152} # Max port -TURN_MAX_PORT=65535 +TURN_MAX_PORT=${TURN_MAX_PORT:-65535} VOLUME_PREFIX="netbird-" MGMT_VOLUMESUFFIX="mgmt" From ea88ec6d27267ddb1c6f16db27a64d16eaeb1c5d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 5 Apr 2023 11:42:14 +0200 Subject: [PATCH 48/50] Roolback configurable port defaults from setup.env --- infrastructure_files/base.setup.env | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 8fa58ffc3..521c0d332 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -3,7 +3,7 @@ # Management API # Management API port -NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073} +NETBIRD_MGMT_API_PORT=33073 # Management API endpoint address, used by the Dashboard NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT # Management Certficate file path. These are generated by the Dashboard container @@ -16,7 +16,7 @@ NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted} # Signal NETBIRD_SIGNAL_PROTOCOL="http" -NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000} +NETBIRD_SIGNAL_PORT=10000 # Turn credentials # User @@ -24,9 +24,9 @@ TURN_USER=self # Password. If empty, the configure.sh will generate one with openssl TURN_PASSWORD= # Min port -TURN_MIN_PORT=${TURN_MIN_PORT:-49152} +TURN_MIN_PORT=49152 # Max port -TURN_MAX_PORT=${TURN_MAX_PORT:-65535} +TURN_MAX_PORT=65535 VOLUME_PREFIX="netbird-" MGMT_VOLUMESUFFIX="mgmt" From e903522f8cca126bcf5dc5ce958a2b9d05340682 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 5 Apr 2023 15:22:06 +0200 Subject: [PATCH 49/50] Configurable port defaults from setup.env (#783) Allow configuring management and signal ports from setup.env Allow configuring Coturn range from setup.env --- infrastructure_files/base.setup.env | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 521c0d332..8fa58ffc3 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -3,7 +3,7 @@ # Management API # Management API port -NETBIRD_MGMT_API_PORT=33073 +NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073} # Management API endpoint address, used by the Dashboard NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT # Management Certficate file path. These are generated by the Dashboard container @@ -16,7 +16,7 @@ NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted} # Signal NETBIRD_SIGNAL_PROTOCOL="http" -NETBIRD_SIGNAL_PORT=10000 +NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000} # Turn credentials # User @@ -24,9 +24,9 @@ TURN_USER=self # Password. If empty, the configure.sh will generate one with openssl TURN_PASSWORD= # Min port -TURN_MIN_PORT=49152 +TURN_MIN_PORT=${TURN_MIN_PORT:-49152} # Max port -TURN_MAX_PORT=65535 +TURN_MAX_PORT=${TURN_MAX_PORT:-65535} VOLUME_PREFIX="netbird-" MGMT_VOLUMESUFFIX="mgmt" From 32b345991a38a4456f88ab045b67d4cbcf90620c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 5 Apr 2023 17:46:34 +0200 Subject: [PATCH 50/50] Support remote scope and use id token configuration (#784) Some IDP requires different scope requests and issue access tokens for different purposes This change allow for remote configurable scopes and the use of ID token --- client/android/login.go | 9 +- client/cmd/login.go | 9 +- client/internal/device_auth.go | 14 +++ client/internal/oauth.go | 76 +++++++------ client/internal/oauth_test.go | 35 +++--- client/server/server.go | 9 +- management/cmd/management.go | 4 + management/proto/management.pb.go | 177 +++++++++++++++++------------- management/proto/management.proto | 4 + management/server/config.go | 10 ++ management/server/grpcserver.go | 2 + 11 files changed, 197 insertions(+), 152 deletions(-) diff --git a/client/android/login.go b/client/android/login.go index 0c11c0cce..518942cb6 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -166,7 +166,7 @@ func (a *Auth) login(urlOpener URLOpener) error { if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } - jwtToken = tokenInfo.AccessToken + jwtToken = tokenInfo.GetTokenToUse() } err = a.withBackOff(a.ctx, func() error { @@ -199,12 +199,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo, } } - hostedClient := internal.NewHostedDeviceFlow( - providerConfig.ProviderConfig.Audience, - providerConfig.ProviderConfig.ClientID, - providerConfig.ProviderConfig.TokenEndpoint, - providerConfig.ProviderConfig.DeviceAuthEndpoint, - ) + hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := hostedClient.RequestDeviceCode(context.TODO()) if err != nil { diff --git a/client/cmd/login.go b/client/cmd/login.go index 13b4b335c..92d69b6ee 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -135,7 +135,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } - jwtToken = tokenInfo.AccessToken + jwtToken = tokenInfo.GetTokenToUse() } err = WithBackOff(func() error { @@ -172,12 +172,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int } } - hostedClient := internal.NewHostedDeviceFlow( - providerConfig.ProviderConfig.Audience, - providerConfig.ProviderConfig.ClientID, - providerConfig.ProviderConfig.TokenEndpoint, - providerConfig.ProviderConfig.DeviceAuthEndpoint, - ) + hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) flowInfo, err := hostedClient.RequestDeviceCode(context.TODO()) if err != nil { diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index d2396242b..0273bb8e4 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -34,6 +34,10 @@ type ProviderConfig struct { TokenEndpoint string // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code DeviceAuthEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it @@ -91,9 +95,16 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain, TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(), + Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(), + UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(), }, } + // keep compatibility with older management versions + if deviceAuthorizationFlow.ProviderConfig.Scope == "" { + deviceAuthorizationFlow.ProviderConfig.Scope = "openid" + } + err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) if err != nil { return DeviceAuthorizationFlow{}, err @@ -116,5 +127,8 @@ func isProviderConfigValid(config ProviderConfig) error { if config.DeviceAuthEndpoint == "" { return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint") } + if config.Scope == "" { + return fmt.Errorf(errorMSGFormat, "Device Auth Scopes") + } return nil } diff --git a/client/internal/oauth.go b/client/internal/oauth.go index ae327a620..2d237925d 100644 --- a/client/internal/oauth.go +++ b/client/internal/oauth.go @@ -35,15 +35,6 @@ type DeviceAuthInfo struct { Interval int `json:"interval"` } -// TokenInfo holds information of issued access token -type TokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - IDToken string `json:"id_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` -} - // HostedGrantType grant type for device flow on Hosted const ( HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code" @@ -52,16 +43,7 @@ const ( // Hosted client type Hosted struct { - // Hosted API Audience for validation - Audience string - // Hosted Native application client id - ClientID string - // Hosted Native application request scope - Scope string - // TokenEndpoint to request access token - TokenEndpoint string - // DeviceAuthEndpoint to request device authorization code - DeviceAuthEndpoint string + providerConfig ProviderConfig HTTPClient HTTPClient } @@ -70,7 +52,7 @@ type Hosted struct { type RequestDeviceCodePayload struct { Audience string `json:"audience"` ClientID string `json:"client_id"` - Scope string `json:"scope"` + Scope string `json:"scope"` } // TokenRequestPayload used for requesting the auth0 token @@ -93,8 +75,26 @@ type Claims struct { Audience interface{} `json:"aud"` } +// TokenInfo holds information of issued access token +type TokenInfo struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + UseIDToken bool `json:"-"` +} + +// GetTokenToUse returns either the access or id token based on UseIDToken field +func (t TokenInfo) GetTokenToUse() string { + if t.UseIDToken { + return t.IDToken + } + return t.AccessToken +} + // NewHostedDeviceFlow returns an Hosted OAuth client -func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted { +func NewHostedDeviceFlow(config ProviderConfig) *Hosted { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -104,27 +104,23 @@ func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, } return &Hosted{ - Audience: audience, - ClientID: clientID, - Scope: "openid", - TokenEndpoint: tokenEndpoint, - HTTPClient: httpClient, - DeviceAuthEndpoint: deviceAuthEndpoint, + providerConfig: config, + HTTPClient: httpClient, } } // GetClientID returns the provider client id func (h *Hosted) GetClientID(ctx context.Context) string { - return h.ClientID + return h.providerConfig.ClientID } // RequestDeviceCode requests a device code login flow information from Hosted func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) { form := url.Values{} - form.Add("client_id", h.ClientID) - form.Add("audience", h.Audience) - form.Add("scope", h.Scope) - req, err := http.NewRequest("POST", h.DeviceAuthEndpoint, + form.Add("client_id", h.providerConfig.ClientID) + form.Add("audience", h.providerConfig.Audience) + form.Add("scope", h.providerConfig.Scope) + req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint, strings.NewReader(form.Encode())) if err != nil { return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err) @@ -157,10 +153,10 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) { form := url.Values{} - form.Add("client_id", h.ClientID) + form.Add("client_id", h.providerConfig.ClientID) form.Add("grant_type", HostedGrantType) form.Add("device_code", info.DeviceCode) - req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode())) + req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode())) if err != nil { return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err) } @@ -225,18 +221,20 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) } - err = isValidAccessToken(tokenResponse.AccessToken, h.Audience) - if err != nil { - return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) - } - tokenInfo := TokenInfo{ AccessToken: tokenResponse.AccessToken, TokenType: tokenResponse.TokenType, RefreshToken: tokenResponse.RefreshToken, IDToken: tokenResponse.IDToken, ExpiresIn: tokenResponse.ExpiresIn, + UseIDToken: h.providerConfig.UseIDToken, } + + err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience) + if err != nil { + return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + } + return tokenInfo, err } } diff --git a/client/internal/oauth_test.go b/client/internal/oauth_test.go index 3a9e2a0c2..aa71fa0eb 100644 --- a/client/internal/oauth_test.go +++ b/client/internal/oauth_test.go @@ -3,14 +3,15 @@ package internal import ( "context" "fmt" - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/require" "io" "net/http" "net/url" "strings" "testing" "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/require" ) type mockHTTPClient struct { @@ -113,12 +114,15 @@ func TestHosted_RequestDeviceCode(t *testing.T) { } hosted := Hosted{ - Audience: expectedAudience, - ClientID: expectedClientID, - Scope: expectedScope, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - HTTPClient: &httpClient, + providerConfig: ProviderConfig{ + Audience: expectedAudience, + ClientID: expectedClientID, + Scope: expectedScope, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + UseIDToken: false, + }, + HTTPClient: &httpClient, } authInfo, err := hosted.RequestDeviceCode(context.TODO()) @@ -275,12 +279,15 @@ func TestHosted_WaitToken(t *testing.T) { } hosted := Hosted{ - Audience: testCase.inputAudience, - ClientID: clientID, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - HTTPClient: &httpClient, - } + providerConfig: ProviderConfig{ + Audience: testCase.inputAudience, + ClientID: clientID, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + Scope: "openid", + UseIDToken: false, + }, + HTTPClient: &httpClient} ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) defer cancel() diff --git a/client/server/server.go b/client/server/server.go index 6d5a08c59..fba82c7e4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -223,12 +223,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } } - hostedClient := internal.NewHostedDeviceFlow( - providerConfig.ProviderConfig.Audience, - providerConfig.ProviderConfig.ClientID, - providerConfig.ProviderConfig.TokenEndpoint, - providerConfig.ProviderConfig.DeviceAuthEndpoint, - ) + hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig) if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { @@ -344,7 +339,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin s.oauthAuthFlow.expiresAt = time.Now() s.mutex.Unlock() - if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.AccessToken); err != nil { + if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.GetTokenToUse()); err != nil { state.Set(loginStatus) return nil, err } diff --git a/management/cmd/management.go b/management/cmd/management.go index 38535462f..d956fcff5 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -417,6 +417,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", u.Host, config.DeviceAuthorizationFlow.ProviderConfig.Domain) config.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host + + if config.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { + config.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope + } } } diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 022cc1408..ff2133526 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,15 +1,15 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.12.4 +// protoc v3.21.9 // source: management.proto package proto import ( - timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -611,7 +611,7 @@ type ServerKeyResponse struct { // Server's Wireguard public key Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` // Key expiration timestamp after which the key should be fetched again by the client - ExpiresAt *timestamp.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` + ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiresAt,proto3" json:"expiresAt,omitempty"` // Version of the Wiretrustee Management Service protocol Version int32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"` } @@ -655,7 +655,7 @@ func (x *ServerKeyResponse) GetKey() string { return "" } -func (x *ServerKeyResponse) GetExpiresAt() *timestamp.Timestamp { +func (x *ServerKeyResponse) GetExpiresAt() *timestamppb.Timestamp { if x != nil { return x.ExpiresAt } @@ -1331,6 +1331,10 @@ type ProviderConfig struct { DeviceAuthEndpoint string `protobuf:"bytes,5,opt,name=DeviceAuthEndpoint,proto3" json:"DeviceAuthEndpoint,omitempty"` // TokenEndpoint is an endpoint to request auth token. TokenEndpoint string `protobuf:"bytes,6,opt,name=TokenEndpoint,proto3" json:"TokenEndpoint,omitempty"` + // Scopes provides the scopes to be included in the token request + Scope string `protobuf:"bytes,7,opt,name=Scope,proto3" json:"Scope,omitempty"` + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool `protobuf:"varint,8,opt,name=UseIDToken,proto3" json:"UseIDToken,omitempty"` } func (x *ProviderConfig) Reset() { @@ -1407,6 +1411,20 @@ func (x *ProviderConfig) GetTokenEndpoint() string { return "" } +func (x *ProviderConfig) GetScope() string { + if x != nil { + return x.Scope + } + return "" +} + +func (x *ProviderConfig) GetUseIDToken() bool { + if x != nil { + return x.UseIDToken + } + return false +} + // Route represents a route.Route object type Route struct { state protoimpl.MessageState @@ -2000,7 +2018,7 @@ var file_management_proto_rawDesc = []byte{ 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, - 0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x10, 0x00, 0x22, 0x90, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, @@ -2013,81 +2031,84 @@ var file_management_proto_rawDesc = []byte{ 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, - 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, - 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, - 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, - 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, - 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, - 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, - 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, - 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, - 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, - 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, - 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, - 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, - 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, + 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, + 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, + 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, + 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, + 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, + 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, + 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, + 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, + 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, + 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, + 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, + 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32, + 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, - 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, - 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, + 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, + 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, + 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2132,7 +2153,7 @@ var file_management_proto_goTypes = []interface{}{ (*SimpleRecord)(nil), // 24: management.SimpleRecord (*NameServerGroup)(nil), // 25: management.NameServerGroup (*NameServer)(nil), // 26: management.NameServer - (*timestamp.Timestamp)(nil), // 27: google.protobuf.Timestamp + (*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp } var file_management_proto_depIdxs = []int32{ 11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig diff --git a/management/proto/management.proto b/management/proto/management.proto index 2c3c18c97..5447a9ee6 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -246,6 +246,10 @@ message ProviderConfig { string DeviceAuthEndpoint = 5; // TokenEndpoint is an endpoint to request auth token. string TokenEndpoint = 6; + // Scopes provides the scopes to be included in the token request + string Scope = 7; + // UseIDToken indicates if the id token should be used for authentication + bool UseIDToken = 8; } // Route represents a route.Route object diff --git a/management/server/config.go b/management/server/config.go index f8d7d8db8..9ec16b3e8 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -24,6 +24,11 @@ const ( NONE Provider = "none" ) +const ( + // DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow + DefaultDeviceAuthFlowScope string = "openid" +) + // Config of the Management service type Config struct { Stuns []*Host @@ -49,6 +54,7 @@ func (c Config) GetAuthAudiences() []string { return audiences } + // TURNConfig is a config of the TURNCredentialsManager type TURNConfig struct { TimeBasedCredentials bool @@ -108,6 +114,10 @@ type ProviderConfig struct { TokenEndpoint string // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code DeviceAuthEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool } // validateURL validates input http url diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index e43c767c3..f63a55d65 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -510,6 +510,8 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience, DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint, TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint, + Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope, + UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken, }, }