From 7d791620a624b8fd81f8b8e8aecf6fcf888e57c1 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Tue, 27 Jan 2026 09:42:20 +0100 Subject: [PATCH] Add user invite link feature for embedded IdP (#5157) --- management/server/account/manager.go | 6 + management/server/activity/codes.go | 10 + management/server/http/handler.go | 10 + .../handlers/instance/instance_handler.go | 35 + .../instance/instance_handler_test.go | 54 + .../http/handlers/users/invites_handler.go | 263 +++++ .../handlers/users/invites_handler_test.go | 642 +++++++++++ .../server/http/middleware/rate_limiter.go | 26 + .../http/middleware/rate_limiter_test.go | 158 +++ management/server/instance/manager.go | 175 +++ management/server/instance/version_test.go | 285 +++++ management/server/mock_server/account_mock.go | 48 + management/server/store/sql_store.go | 126 +- .../store/sql_store_user_invite_test.go | 520 +++++++++ management/server/store/store.go | 7 + management/server/types/user_invite.go | 201 ++++ management/server/types/user_invite_test.go | 355 ++++++ management/server/user.go | 366 ++++++ management/server/user_invite_test.go | 1010 +++++++++++++++++ shared/management/http/api/openapi.yml | 416 ++++++- shared/management/http/api/types.gen.go | 121 ++ 21 files changed, 4832 insertions(+), 2 deletions(-) create mode 100644 management/server/http/handlers/users/invites_handler.go create mode 100644 management/server/http/handlers/users/invites_handler_test.go create mode 100644 management/server/http/middleware/rate_limiter_test.go create mode 100644 management/server/instance/version_test.go create mode 100644 management/server/store/sql_store_user_invite_test.go create mode 100644 management/server/types/user_invite.go create mode 100644 management/server/types/user_invite_test.go create mode 100644 management/server/user_invite_test.go diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 11af67358..5e9bb42a2 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -30,6 +30,12 @@ type Manager interface { autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error) + CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) + AcceptUserInvite(ctx context.Context, token, password string) error + RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) + GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) + ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) + DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index e9eaa644b..e83eeb90a 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -199,6 +199,11 @@ const ( UserPasswordChanged Activity = 103 + UserInviteLinkCreated Activity = 104 + UserInviteLinkAccepted Activity = 105 + UserInviteLinkRegenerated Activity = 106 + UserInviteLinkDeleted Activity = 107 + AccountDeleted Activity = 99999 ) @@ -327,6 +332,11 @@ var activityMap = map[Activity]Code{ JobCreatedByUser: {"Create Job for peer", "peer.job.create"}, UserPasswordChanged: {"User password changed", "user.password.change"}, + + UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"}, + UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"}, + UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"}, + UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 64f914afe..32a97ff44 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -68,6 +68,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks if err := bypass.AddBypassPath("/api/setup"); err != nil { return nil, fmt.Errorf("failed to add bypass path: %w", err) } + // Public invite endpoints (tokens start with nbi_) + if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil { + return nil, fmt.Errorf("failed to add bypass path: %w", err) + } + if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil { + return nil, fmt.Errorf("failed to add bypass path: %w", err) + } var rateLimitingConfig *middleware.RateLimiterConfig if os.Getenv(rateLimitingEnabledKey) == "true" { @@ -132,6 +139,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router) peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) + users.AddInvitesEndpoints(accountManager, router) + users.AddPublicInvitesEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) policies.AddEndpoints(accountManager, LocationManager, router) policies.AddPostureCheckEndpoints(accountManager, LocationManager, router) @@ -145,6 +154,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks recordsManager.RegisterEndpoints(router, rManager) idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) + instance.AddVersionEndpoint(instanceManager, router) // Mount embedded IdP handler at /oauth2 path if configured if embeddedIdpEnabled { diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index 889c3133e..5d8baaf8d 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -28,6 +28,15 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS") } +// AddVersionEndpoint registers the authenticated version endpoint. +func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) { + h := &handler{ + instanceManager: instanceManager, + } + + router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS") +} + // getInstanceStatus returns the instance status including whether setup is required. // This endpoint is unauthenticated. func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { @@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) { Email: userData.Email, }) } + +// getVersionInfo returns version information for NetBird components. +// This endpoint requires authentication. +func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) { + versionInfo, err := h.instanceManager.GetVersionInfo(r.Context()) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get version info: %v", err) + util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w) + return + } + + resp := api.InstanceVersionInfo{ + ManagementCurrentVersion: versionInfo.CurrentVersion, + ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable, + } + + if versionInfo.DashboardVersion != "" { + resp.DashboardAvailableVersion = &versionInfo.DashboardVersion + } + + if versionInfo.ManagementVersion != "" { + resp.ManagementAvailableVersion = &versionInfo.ManagementVersion + } + + util.WriteJSONObject(r.Context(), w, resp) +} diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 7a3a2bc88..470079c85 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -25,6 +25,7 @@ type mockInstanceManager struct { isSetupRequired bool isSetupRequiredFn func(ctx context.Context) (bool, error) createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error) } func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) { @@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo }, nil } +func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) { + if m.getVersionInfoFn != nil { + return m.getVersionInfoFn(ctx) + } + return &nbinstance.VersionInfo{ + CurrentVersion: "0.34.0", + DashboardVersion: "2.0.0", + ManagementVersion: "0.35.0", + ManagementUpdateAvailable: true, + }, nil +} + var _ nbinstance.Manager = (*mockInstanceManager)(nil) func setupTestRouter(manager nbinstance.Manager) *mux.Router { @@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) } + +func TestGetVersionInfo_Success(t *testing.T) { + manager := &mockInstanceManager{} + router := mux.NewRouter() + AddVersionEndpoint(manager, router) + + req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var response api.InstanceVersionInfo + err := json.NewDecoder(rec.Body).Decode(&response) + require.NoError(t, err) + + assert.Equal(t, "0.34.0", response.ManagementCurrentVersion) + assert.NotNil(t, response.DashboardAvailableVersion) + assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion) + assert.NotNil(t, response.ManagementAvailableVersion) + assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion) + assert.True(t, response.ManagementUpdateAvailable) +} + +func TestGetVersionInfo_Error(t *testing.T) { + manager := &mockInstanceManager{ + getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) { + return nil, errors.New("failed to fetch versions") + }, + } + router := mux.NewRouter() + AddVersionEndpoint(manager, router) + + req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) +} diff --git a/management/server/http/handlers/users/invites_handler.go b/management/server/http/handlers/users/invites_handler.go new file mode 100644 index 000000000..0f0f57c29 --- /dev/null +++ b/management/server/http/handlers/users/invites_handler.go @@ -0,0 +1,263 @@ +package users + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "time" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/account" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks +var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{ + RequestsPerMinute: 10, // 10 attempts per minute per IP + Burst: 5, // Allow burst of 5 requests + CleanupInterval: 10 * time.Minute, + LimiterTTL: 30 * time.Minute, +}) + +// toUserInviteResponse converts a UserInvite to an API response. +func toUserInviteResponse(invite *types.UserInvite) api.UserInvite { + autoGroups := invite.UserInfo.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + var inviteLink *string + if invite.InviteToken != "" { + inviteLink = &invite.InviteToken + } + return api.UserInvite{ + Id: invite.UserInfo.ID, + Email: invite.UserInfo.Email, + Name: invite.UserInfo.Name, + Role: invite.UserInfo.Role, + AutoGroups: autoGroups, + ExpiresAt: invite.InviteExpiresAt.UTC(), + CreatedAt: invite.InviteCreatedAt.UTC(), + Expired: time.Now().After(invite.InviteExpiresAt), + InviteToken: inviteLink, + } +} + +// invitesHandler handles user invite operations +type invitesHandler struct { + accountManager account.Manager +} + +// AddInvitesEndpoints registers invite-related endpoints +func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) { + h := &invitesHandler{accountManager: accountManager} + + // Authenticated endpoints (require admin) + router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS") + router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS") +} + +// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting +func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) { + h := &invitesHandler{accountManager: accountManager} + + // Create a subrouter for public invite endpoints with rate limiting middleware + publicRouter := router.PathPrefix("/users/invites").Subrouter() + publicRouter.Use(publicInviteRateLimiter.Middleware) + + // Public endpoints (no auth required, protected by token and rate limited) + publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS") + publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS") +} + +// listInvites handles GET /api/users/invites +func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := make([]api.UserInvite, 0, len(invites)) + for _, invite := range invites { + resp = append(resp, toUserInviteResponse(invite)) + } + + util.WriteJSONObject(r.Context(), w, resp) +} + +// createInvite handles POST /api/users/invites +func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) { + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.UserInviteCreateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + invite := &types.UserInfo{ + Email: req.Email, + Name: req.Name, + Role: req.Role, + AutoGroups: req.AutoGroups, + } + + expiresIn := 0 + if req.ExpiresIn != nil { + expiresIn = *req.ExpiresIn + } + + result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + result.InviteCreatedAt = time.Now().UTC() + resp := toUserInviteResponse(result) + util.WriteJSONObject(r.Context(), w, &resp) +} + +// getInviteInfo handles GET /api/users/invites/{token} +func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) { + + vars := mux.Vars(r) + token := vars["token"] + if token == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w) + return + } + + info, err := h.accountManager.GetUserInviteInfo(r.Context(), token) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + expiresAt := info.ExpiresAt.UTC() + util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{ + Email: info.Email, + Name: info.Name, + ExpiresAt: expiresAt, + Valid: info.Valid, + InvitedBy: info.InvitedBy, + }) +} + +// acceptInvite handles POST /api/users/invites/{token}/accept +func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) { + + vars := mux.Vars(r) + token := vars["token"] + if token == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w) + return + } + + var req api.UserInviteAcceptRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true}) +} + +// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate +func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + inviteID := vars["inviteId"] + if inviteID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w) + return + } + + var req api.UserInviteRegenerateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + // Allow empty body (io.EOF) - expiresIn is optional + if !errors.Is(err, io.EOF) { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + } + + expiresIn := 0 + if req.ExpiresIn != nil { + expiresIn = *req.ExpiresIn + } + + result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + expiresAt := result.InviteExpiresAt.UTC() + util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{ + InviteToken: result.InviteToken, + InviteExpiresAt: expiresAt, + }) +} + +// deleteInvite handles DELETE /api/users/invites/{inviteId} +func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + inviteID := vars["inviteId"] + if inviteID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w) + return + } + + err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go new file mode 100644 index 000000000..80826b9d4 --- /dev/null +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -0,0 +1,642 @@ +package users + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + testAccountID = "test-account-id" + testUserID = "test-user-id" + testInviteID = "test-invite-id" + testInviteToken = "nbi_testtoken123456789012345678" + testEmail = "invite@example.com" + testName = "Test User" +) + +func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler { + return &invitesHandler{ + accountManager: am, + } +} + +func TestListInvites(t *testing.T) { + now := time.Now().UTC() + testInvites := []*types.UserInvite{ + { + UserInfo: &types.UserInfo{ + ID: "invite-1", + Email: "user1@example.com", + Name: "User One", + Role: "user", + AutoGroups: []string{"group-1"}, + }, + InviteExpiresAt: now.Add(24 * time.Hour), + InviteCreatedAt: now, + }, + { + UserInfo: &types.UserInfo{ + ID: "invite-2", + Email: "user2@example.com", + Name: "User Two", + Role: "admin", + AutoGroups: nil, + }, + InviteExpiresAt: now.Add(-1 * time.Hour), // Expired + InviteCreatedAt: now.Add(-48 * time.Hour), + }, + } + + tt := []struct { + name string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) + expectedCount int + }{ + { + name: "successful list", + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + return testInvites, nil + }, + expectedCount: 2, + }, + { + name: "empty list", + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + return []*types.UserInvite{}, nil + }, + expectedCount: 0, + }, + { + name: "permission denied", + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + return nil, status.NewPermissionDeniedError() + }, + expectedCount: 0, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + ListUserInvitesFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + rr := httptest.NewRecorder() + handler.listInvites(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp []api.UserInvite + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Len(t, resp, tc.expectedCount) + } + }) + } +} + +func TestCreateInvite(t *testing.T) { + now := time.Now().UTC() + expiresAt := now.Add(72 * time.Hour) + + tt := []struct { + name string + requestBody string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) + }{ + { + name: "successful create", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: testInviteID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: invite.AutoGroups, + Status: string(types.UserStatusInvited), + }, + InviteToken: testInviteToken, + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "successful create with custom expiration", + requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + assert.Equal(t, 3600, expiresIn) + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: testInviteID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: []string{}, + Status: string(types.UserStatusInvited), + }, + InviteToken: testInviteToken, + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "user already exists", + requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusConflict, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists") + }, + }, + { + name: "invite already exists", + requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusConflict, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email") + }, + }, + { + name: "permission denied", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.NewPermissionDeniedError() + }, + }, + { + name: "embedded IDP not enabled", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + }, + }, + { + name: "invalid JSON", + requestBody: `{invalid json}`, + expectedStatus: http.StatusBadRequest, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + CreateUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody)) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + + rr := httptest.NewRecorder() + handler.createInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInvite + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Equal(t, testInviteID, resp.Id) + assert.NotNil(t, resp.InviteToken) + assert.NotEmpty(t, *resp.InviteToken) + } + }) + } +} + +func TestGetInviteInfo(t *testing.T) { + now := time.Now().UTC() + + tt := []struct { + name string + token string + expectedStatus int + mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error) + }{ + { + name: "successful get valid invite", + token: testInviteToken, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return &types.UserInviteInfo{ + Email: testEmail, + Name: testName, + ExpiresAt: now.Add(24 * time.Hour), + Valid: true, + InvitedBy: "Admin User", + }, nil + }, + }, + { + name: "successful get expired invite", + token: testInviteToken, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return &types.UserInviteInfo{ + Email: testEmail, + Name: testName, + ExpiresAt: now.Add(-24 * time.Hour), + Valid: false, + InvitedBy: "Admin User", + }, nil + }, + }, + { + name: "invite not found", + token: "nbi_invalidtoken1234567890123456", + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return nil, status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "invalid token format", + token: "invalid", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) { + return nil, status.Errorf(status.InvalidArgument, "invalid invite token") + }, + }, + { + name: "missing token", + token: "", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + GetUserInviteInfoFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil) + if tc.token != "" { + req = mux.SetURLVars(req, map[string]string{"token": tc.token}) + } + + rr := httptest.NewRecorder() + handler.getInviteInfo(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInviteInfo + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.Equal(t, testEmail, resp.Email) + assert.Equal(t, testName, resp.Name) + } + }) + } +} + +func TestAcceptInvite(t *testing.T) { + tt := []struct { + name string + token string + requestBody string + expectedStatus int + mockFunc func(ctx context.Context, token, password string) error + }{ + { + name: "successful accept", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, token, password string) error { + return nil + }, + }, + { + name: "invite not found", + token: "nbi_invalidtoken1234567890123456", + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "invite expired", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "invite has expired") + }, + }, + { + name: "embedded IDP not enabled", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + }, + }, + { + name: "missing token", + token: "", + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + { + name: "invalid JSON", + token: testInviteToken, + requestBody: `{invalid}`, + expectedStatus: http.StatusBadRequest, + mockFunc: nil, + }, + { + name: "password too short", + token: testInviteToken, + requestBody: `{"password":"Short1!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long") + }, + }, + { + name: "password missing digit", + token: testInviteToken, + requestBody: `{"password":"NoDigitPass!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must contain at least one digit") + }, + }, + { + name: "password missing uppercase", + token: testInviteToken, + requestBody: `{"password":"nouppercase1!"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter") + }, + }, + { + name: "password missing special character", + token: testInviteToken, + requestBody: `{"password":"NoSpecial123"}`, + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.InvalidArgument, "password must contain at least one special character") + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + AcceptUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody)) + if tc.token != "" { + req = mux.SetURLVars(req, map[string]string{"token": tc.token}) + } + + rr := httptest.NewRecorder() + handler.acceptInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInviteAcceptResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.True(t, resp.Success) + } + }) + } +} + +func TestRegenerateInvite(t *testing.T) { + now := time.Now().UTC() + expiresAt := now.Add(72 * time.Hour) + + tt := []struct { + name string + inviteID string + requestBody string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) + }{ + { + name: "successful regenerate with empty body", + inviteID: testInviteID, + requestBody: "", + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + assert.Equal(t, 0, expiresIn) + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: inviteID, + Email: testEmail, + }, + InviteToken: "nbi_newtoken12345678901234567890", + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "successful regenerate with custom expiration", + inviteID: testInviteID, + requestBody: `{"expires_in":7200}`, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + assert.Equal(t, 7200, expiresIn) + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: inviteID, + Email: testEmail, + }, + InviteToken: "nbi_newtoken12345678901234567890", + InviteExpiresAt: expiresAt, + }, nil + }, + }, + { + name: "invite not found", + inviteID: "non-existent-invite", + requestBody: "", + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "permission denied", + inviteID: testInviteID, + requestBody: "", + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + return nil, status.NewPermissionDeniedError() + }, + }, + { + name: "missing invite ID", + inviteID: "", + requestBody: "", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + { + name: "invalid JSON should return error", + inviteID: testInviteID, + requestBody: `{invalid json}`, + expectedStatus: http.StatusBadRequest, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + RegenerateUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + var body io.Reader + if tc.requestBody != "" { + body = bytes.NewBufferString(tc.requestBody) + } + + req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + if tc.inviteID != "" { + req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID}) + } + + rr := httptest.NewRecorder() + handler.regenerateInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedStatus == http.StatusOK { + var resp api.UserInviteRegenerateResponse + err := json.NewDecoder(rr.Body).Decode(&resp) + require.NoError(t, err) + assert.NotEmpty(t, resp.InviteToken) + } + }) + } +} + +func TestDeleteInvite(t *testing.T) { + tt := []struct { + name string + inviteID string + expectedStatus int + mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error + }{ + { + name: "successful delete", + inviteID: testInviteID, + expectedStatus: http.StatusOK, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return nil + }, + }, + { + name: "invite not found", + inviteID: "non-existent-invite", + expectedStatus: http.StatusNotFound, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return status.Errorf(status.NotFound, "invite not found") + }, + }, + { + name: "permission denied", + inviteID: testInviteID, + expectedStatus: http.StatusForbidden, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return status.NewPermissionDeniedError() + }, + }, + { + name: "embedded IDP not enabled", + inviteID: testInviteID, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + }, + }, + { + name: "missing invite ID", + inviteID: "", + expectedStatus: http.StatusUnprocessableEntity, + mockFunc: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{ + DeleteUserInviteFunc: tc.mockFunc, + } + handler := setupInvitesTestHandler(am) + + req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + }) + if tc.inviteID != "" { + req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID}) + } + + rr := httptest.NewRecorder() + handler.deleteInvite(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + }) + } +} diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go index a6266d4f3..936b34319 100644 --- a/management/server/http/middleware/rate_limiter.go +++ b/management/server/http/middleware/rate_limiter.go @@ -2,10 +2,14 @@ package middleware import ( "context" + "net" + "net/http" "sync" "time" "golang.org/x/time/rate" + + "github.com/netbirdio/netbird/shared/management/http/util" ) // RateLimiterConfig holds configuration for the API rate limiter @@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) { defer rl.mu.Unlock() delete(rl.limiters, key) } + +// Middleware returns an HTTP middleware that rate limits requests by client IP. +// Returns 429 Too Many Requests if the rate limit is exceeded. +func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clientIP := getClientIP(r) + if !rl.Allow(clientIP) { + util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) + return + } + next.ServeHTTP(w, r) + }) +} + +// getClientIP extracts the client IP address from the request. +func getClientIP(r *http.Request) string { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return ip +} diff --git a/management/server/http/middleware/rate_limiter_test.go b/management/server/http/middleware/rate_limiter_test.go new file mode 100644 index 000000000..68f804e57 --- /dev/null +++ b/management/server/http/middleware/rate_limiter_test.go @@ -0,0 +1,158 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAPIRateLimiter_Allow(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, // 1 per second + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + // First two requests should be allowed (burst) + assert.True(t, rl.Allow("test-key")) + assert.True(t, rl.Allow("test-key")) + + // Third request should be denied (exceeded burst) + assert.False(t, rl.Allow("test-key")) + + // Different key should be allowed + assert.True(t, rl.Allow("different-key")) +} + +func TestAPIRateLimiter_Middleware(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, // 1 per second + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + // Create a simple handler that returns 200 OK + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Wrap with rate limiter middleware + handler := rl.Middleware(nextHandler) + + // First two requests should pass (burst) + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1) + } + + // Third request should be rate limited + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusTooManyRequests, rr.Code) +} + +func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := rl.Middleware(nextHandler) + + // Request from first IP + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "192.168.1.1:12345" + rr1 := httptest.NewRecorder() + handler.ServeHTTP(rr1, req1) + assert.Equal(t, http.StatusOK, rr1.Code) + + // Second request from first IP should be rate limited + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "192.168.1.1:12345" + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + assert.Equal(t, http.StatusTooManyRequests, rr2.Code) + + // Request from different IP should be allowed + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "192.168.1.2:12345" + rr3 := httptest.NewRecorder() + handler.ServeHTTP(rr3, req3) + assert.Equal(t, http.StatusOK, rr3.Code) +} + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + expected string + }{ + { + name: "remote addr with port", + remoteAddr: "192.168.1.1:12345", + expected: "192.168.1.1", + }, + { + name: "remote addr without port", + remoteAddr: "192.168.1.1", + expected: "192.168.1.1", + }, + { + name: "IPv6 with port", + remoteAddr: "[::1]:12345", + expected: "::1", + }, + { + name: "IPv6 without port", + remoteAddr: "::1", + expected: "::1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = tc.remoteAddr + assert.Equal(t, tc.expected, getClientIP(req)) + }) + } +} + +func TestAPIRateLimiter_Reset(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + // Use up the burst + assert.True(t, rl.Allow("test-key")) + assert.False(t, rl.Allow("test-key")) + + // Reset the limiter + rl.Reset("test-key") + + // Should be allowed again + assert.True(t, rl.Allow("test-key")) +} diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 6f50e3ff7..6a0509ebd 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -2,18 +2,54 @@ package instance import ( "context" + "encoding/json" "errors" "fmt" + "io" + "net/http" "net/mail" + "strings" "sync" + "time" + goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/version" ) +const ( + // Version endpoints + managementVersionURL = "https://pkgs.netbird.io/releases/latest/version" + dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest" + + // Cache TTL for version information + versionCacheTTL = 60 * time.Minute + + // HTTP client timeout + httpTimeout = 5 * time.Second +) + +// VersionInfo contains version information for NetBird components +type VersionInfo struct { + // CurrentVersion is the running management server version + CurrentVersion string + // DashboardVersion is the latest available dashboard version from GitHub + DashboardVersion string + // ManagementVersion is the latest available management version from GitHub + ManagementVersion string + // ManagementUpdateAvailable indicates if a newer management version is available + ManagementUpdateAvailable bool +} + +// githubRelease represents a GitHub release response +type githubRelease struct { + TagName string `json:"tag_name"` +} + // Manager handles instance-level operations like initial setup. type Manager interface { // IsSetupRequired checks if instance setup is required. @@ -23,6 +59,9 @@ type Manager interface { // CreateOwnerUser creates the initial owner user in the embedded IDP. // This should only be called when IsSetupRequired returns true. CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) + + // GetVersionInfo returns version information for NetBird components. + GetVersionInfo(ctx context.Context) (*VersionInfo, error) } // DefaultManager is the default implementation of Manager. @@ -32,6 +71,12 @@ type DefaultManager struct { setupRequired bool setupMu sync.RWMutex + + // Version caching + httpClient *http.Client + versionMu sync.RWMutex + cachedVersions *VersionInfo + lastVersionFetch time.Time } // NewManager creates a new instance manager. @@ -43,6 +88,9 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) store: store, embeddedIdpManager: embeddedIdp, setupRequired: false, + httpClient: &http.Client{ + Timeout: httpTimeout, + }, } if embeddedIdp != nil { @@ -134,3 +182,130 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error { } return nil } + +// GetVersionInfo returns version information for NetBird components. +func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) { + m.versionMu.RLock() + if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL { + cached := *m.cachedVersions + m.versionMu.RUnlock() + return &cached, nil + } + m.versionMu.RUnlock() + + return m.fetchVersionInfo(ctx) +} + +func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) { + m.versionMu.Lock() + // Double-check after acquiring write lock + if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL { + cached := *m.cachedVersions + m.versionMu.Unlock() + return &cached, nil + } + m.versionMu.Unlock() + + info := &VersionInfo{ + CurrentVersion: version.NetbirdVersion(), + } + + // Fetch management version from pkgs.netbird.io (plain text) + mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL) + if err != nil { + log.WithContext(ctx).Warnf("failed to fetch management version: %v", err) + } else { + info.ManagementVersion = mgmtVersion + info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion) + } + + // Fetch dashboard version from GitHub + dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL) + if err != nil { + log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err) + } else { + info.DashboardVersion = dashVersion + } + + // Update cache + m.versionMu.Lock() + m.cachedVersions = info + m.lastVersionFetch = time.Now() + m.versionMu.Unlock() + + return info, nil +} + +// isNewerVersion returns true if latestVersion is greater than currentVersion +func isNewerVersion(currentVersion, latestVersion string) bool { + current, err := goversion.NewVersion(currentVersion) + if err != nil { + return false + } + + latest, err := goversion.NewVersion(latestVersion) + if err != nil { + return false + } + + return latest.GreaterThan(current) +} + +func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + + req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion()) + + resp, err := m.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 100)) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + return strings.TrimSpace(string(body)), nil +} + +func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion()) + + resp, err := m.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var release githubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + + // Remove 'v' prefix if present + tag := release.TagName + if len(tag) > 0 && tag[0] == 'v' { + tag = tag[1:] + } + + return tag, nil +} diff --git a/management/server/instance/version_test.go b/management/server/instance/version_test.go new file mode 100644 index 000000000..35ba66db8 --- /dev/null +++ b/management/server/instance/version_test.go @@ -0,0 +1,285 @@ +package instance + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockRoundTripper implements http.RoundTripper for testing +type mockRoundTripper struct { + callCount atomic.Int32 + managementVersion string + dashboardVersion string +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.callCount.Add(1) + + var body string + if strings.Contains(req.URL.String(), "pkgs.netbird.io") { + // Plain text response for management version + body = m.managementVersion + } else if strings.Contains(req.URL.String(), "github.com") { + // JSON response for dashboard version + jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion}) + body = string(jsonResp) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + }, nil +} + +func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) { + mockTransport := &mockRoundTripper{ + managementVersion: "0.65.0", + dashboardVersion: "2.10.0", + } + + m := &DefaultManager{ + httpClient: &http.Client{Transport: mockTransport}, + } + + ctx := context.Background() + + info, err := m.GetVersionInfo(ctx) + require.NoError(t, err) + + // CurrentVersion should always be set + assert.NotEmpty(t, info.CurrentVersion) + assert.Equal(t, "0.65.0", info.ManagementVersion) + assert.Equal(t, "2.10.0", info.DashboardVersion) + assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard +} + +func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) { + mockTransport := &mockRoundTripper{ + managementVersion: "0.65.0", + dashboardVersion: "2.10.0", + } + + m := &DefaultManager{ + httpClient: &http.Client{Transport: mockTransport}, + } + + ctx := context.Background() + + // First call + info1, err := m.GetVersionInfo(ctx) + require.NoError(t, err) + assert.NotEmpty(t, info1.CurrentVersion) + assert.Equal(t, "0.65.0", info1.ManagementVersion) + + initialCallCount := mockTransport.callCount.Load() + + // Second call should use cache (no additional HTTP calls) + info2, err := m.GetVersionInfo(ctx) + require.NoError(t, err) + assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion) + assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion) + assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion) + + // Verify no additional HTTP calls were made (cache was used) + assert.Equal(t, initialCallCount, mockTransport.callCount.Load()) +} + +func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) { + tests := []struct { + name string + tagName string + expected string + shouldError bool + }{ + { + name: "tag with v prefix", + tagName: "v1.2.3", + expected: "1.2.3", + }, + { + name: "tag without v prefix", + tagName: "1.2.3", + expected: "1.2.3", + }, + { + name: "tag with prerelease", + tagName: "v2.0.0-beta.1", + expected: "2.0.0-beta.1", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName}) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + version, err := m.fetchGitHubRelease(context.Background(), server.URL) + + if tc.shouldError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expected, version) + } + }) + } +} + +func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + }{ + { + name: "not found", + statusCode: http.StatusNotFound, + body: `{"message": "Not Found"}`, + }, + { + name: "rate limited", + statusCode: http.StatusForbidden, + body: `{"message": "API rate limit exceeded"}`, + }, + { + name: "server error", + statusCode: http.StatusInternalServerError, + body: `{"message": "Internal Server Error"}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.statusCode) + _, _ = w.Write([]byte(tc.body)) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + _, err := m.fetchGitHubRelease(context.Background(), server.URL) + assert.Error(t, err) + }) + } +} + +func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{invalid json}`)) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + _, err := m.fetchGitHubRelease(context.Background(), server.URL) + assert.Error(t, err) +} + +func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"}) + })) + defer server.Close() + + m := &DefaultManager{ + httpClient: &http.Client{Timeout: 5 * time.Second}, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := m.fetchGitHubRelease(ctx, server.URL) + assert.Error(t, err) +} + +func TestIsNewerVersion(t *testing.T) { + tests := []struct { + name string + currentVersion string + latestVersion string + expected bool + }{ + { + name: "latest is newer - minor version", + currentVersion: "0.64.1", + latestVersion: "0.65.0", + expected: true, + }, + { + name: "latest is newer - patch version", + currentVersion: "0.64.1", + latestVersion: "0.64.2", + expected: true, + }, + { + name: "latest is newer - major version", + currentVersion: "0.64.1", + latestVersion: "1.0.0", + expected: true, + }, + { + name: "versions are equal", + currentVersion: "0.64.1", + latestVersion: "0.64.1", + expected: false, + }, + { + name: "current is newer - minor version", + currentVersion: "0.65.0", + latestVersion: "0.64.1", + expected: false, + }, + { + name: "current is newer - patch version", + currentVersion: "0.64.2", + latestVersion: "0.64.1", + expected: false, + }, + { + name: "development version", + currentVersion: "development", + latestVersion: "0.65.0", + expected: false, + }, + { + name: "invalid latest version", + currentVersion: "0.64.1", + latestVersion: "invalid", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isNewerVersion(tc.currentVersion, tc.latestVersion) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 75e971498..026989898 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -139,6 +139,12 @@ type MockAccountManager struct { CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) + CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) + AcceptUserInviteFunc func(ctx context.Context, token, password string) error + RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) + GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error) + ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) + DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error } func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error { @@ -713,6 +719,48 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } +func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + if am.CreateUserInviteFunc != nil { + return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented") +} + +func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error { + if am.AcceptUserInviteFunc != nil { + return am.AcceptUserInviteFunc(ctx, token, password) + } + return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented") +} + +func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + if am.RegenerateUserInviteFunc != nil { + return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn) + } + return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented") +} + +func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) { + if am.GetUserInviteInfoFunc != nil { + return am.GetUserInviteInfoFunc(ctx, token) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented") +} + +func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + if am.ListUserInvitesFunc != nil { + return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented") +} + +func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + if am.DeleteUserInviteFunc != nil { + return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID) + } + return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented") +} + func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if am.GetAccountIDFromUserAuthFunc != nil { return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4fe800636..7f48f510e 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, - &types.Job{}, &zones.Zone{}, &records.Record{}, + &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -815,6 +815,130 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre return &user, nil } +// SaveUserInvite saves a user invite to the database +func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error { + inviteCopy := invite.Copy() + if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { + return fmt.Errorf("encrypt invite: %w", err) + } + + result := s.db.Save(inviteCopy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save user invite to store") + } + return nil +} + +// GetUserInviteByID retrieves a user invite by its ID and account ID +func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invite types.UserInviteRecord + result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "user invite not found") + } + log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invite from store") + } + + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + + return &invite, nil +} + +// GetUserInviteByHashedToken retrieves a user invite by its hashed token +func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invite types.UserInviteRecord + result := tx.Take(&invite, "hashed_token = ?", hashedToken) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "user invite not found") + } + log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invite from store") + } + + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + + return &invite, nil +} + +// GetUserInviteByEmail retrieves a user invite by account ID and email. +// Since email is encrypted with random IVs, we fetch all invites for the account +// and compare emails in memory after decryption. +func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invites []*types.UserInviteRecord + result := tx.Find(&invites, "account_id = ?", accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invites from store") + } + + for _, invite := range invites { + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + if strings.EqualFold(invite.Email, email) { + return invite, nil + } + } + + return nil, status.Errorf(status.NotFound, "user invite not found for email") +} + +// GetAccountUserInvites retrieves all user invites for an account +func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var invites []*types.UserInviteRecord + result := tx.Find(&invites, "account_id = ?", accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get user invites from store") + } + + for _, invite := range invites { + if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt invite: %w", err) + } + } + + return invites, nil +} + +// DeleteUserInvite deletes a user invite by its ID +func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error { + result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete user invite from store") + } + return nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { tx := s.db if lockStrength != LockingStrengthNone { diff --git a/management/server/store/sql_store_user_invite_test.go b/management/server/store/sql_store_user_invite_test.go new file mode 100644 index 000000000..fb6934a2e --- /dev/null +++ b/management/server/store/sql_store_user_invite_test.go @@ -0,0 +1,520 @@ +package store + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/types" +) + +func TestSqlStore_SaveUserInvite(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-1", + AccountID: "account-1", + Email: "test@example.com", + Name: "Test User", + Role: "user", + AutoGroups: []string{"group-1", "group-2"}, + HashedToken: "hashed-token-123", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Verify the invite was saved + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + assert.Equal(t, invite.Email, retrieved.Email) + assert.Equal(t, invite.Name, retrieved.Name) + assert.Equal(t, invite.Role, retrieved.Role) + assert.Equal(t, invite.AutoGroups, retrieved.AutoGroups) + assert.Equal(t, invite.CreatedBy, retrieved.CreatedBy) + }) +} + +func TestSqlStore_SaveUserInvite_Update(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-update", + AccountID: "account-1", + Email: "test@example.com", + Name: "Test User", + Role: "user", + AutoGroups: []string{"group-1"}, + HashedToken: "hashed-token-123", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Update the invite with a new token + invite.HashedToken = "new-hashed-token" + invite.ExpiresAt = time.Now().Add(24 * time.Hour) + + err = store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Verify the update + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + assert.Equal(t, "new-hashed-token", retrieved.HashedToken) + }) +} + +func TestSqlStore_GetUserInviteByID(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-get-by-id", + AccountID: "account-1", + Email: "getbyid@example.com", + Name: "Get By ID User", + Role: "admin", + AutoGroups: []string{}, + HashedToken: "hashed-token-get", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Get by ID - success + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + assert.Equal(t, invite.Email, retrieved.Email) + + // Get by ID - wrong account + _, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID) + assert.Error(t, err) + + // Get by ID - not found + _, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent") + assert.Error(t, err) + }) +} + +func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-get-by-token", + AccountID: "account-1", + Email: "getbytoken@example.com", + Name: "Get By Token User", + Role: "user", + AutoGroups: []string{"group-1"}, + HashedToken: "unique-hashed-token-456", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Get by hashed token - success + retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + assert.Equal(t, invite.Email, retrieved.Email) + + // Get by hashed token - not found + _, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token") + assert.Error(t, err) + }) +} + +func TestSqlStore_GetUserInviteByEmail(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-get-by-email", + AccountID: "account-email-test", + Email: "unique-email@example.com", + Name: "Get By Email User", + Role: "user", + AutoGroups: []string{}, + HashedToken: "hashed-token-email", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Get by email - success + retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + + // Get by email - case insensitive + retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM") + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + + // Get by email - wrong account + _, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email) + assert.Error(t, err) + + // Get by email - not found + _, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com") + assert.Error(t, err) + }) +} + +func TestSqlStore_GetAccountUserInvites(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + accountID := "account-list-invites" + + invites := []*types.UserInviteRecord{ + { + ID: "invite-list-1", + AccountID: accountID, + Email: "user1@example.com", + Name: "User One", + Role: "user", + AutoGroups: []string{"group-1"}, + HashedToken: "hashed-token-list-1", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + }, + { + ID: "invite-list-2", + AccountID: accountID, + Email: "user2@example.com", + Name: "User Two", + Role: "admin", + AutoGroups: []string{"group-2"}, + HashedToken: "hashed-token-list-2", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + }, + { + ID: "invite-list-3", + AccountID: "different-account", + Email: "user3@example.com", + Name: "User Three", + Role: "user", + AutoGroups: []string{}, + HashedToken: "hashed-token-list-3", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + }, + } + + for _, invite := range invites { + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + } + + // Get all invites for the account + retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + assert.Len(t, retrieved, 2) + + // Verify the invites belong to the correct account + for _, invite := range retrieved { + assert.Equal(t, accountID, invite.AccountID) + } + + // Get invites for account with no invites + retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account") + require.NoError(t, err) + assert.Len(t, retrieved, 0) + }) +} + +func TestSqlStore_DeleteUserInvite(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-delete", + AccountID: "account-delete-test", + Email: "delete@example.com", + Name: "Delete User", + Role: "user", + AutoGroups: []string{}, + HashedToken: "hashed-token-delete", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Verify invite exists + _, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + + // Delete the invite + err = store.DeleteUserInvite(ctx, invite.ID) + require.NoError(t, err) + + // Verify invite is deleted + _, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + assert.Error(t, err) + }) +} + +func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-encrypted", + AccountID: "account-encrypted", + Email: "sensitive-email@example.com", + Name: "Sensitive Name", + Role: "user", + AutoGroups: []string{"group-1"}, + HashedToken: "hashed-token-encrypted", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Retrieve and verify decryption works + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + assert.Equal(t, "sensitive-email@example.com", retrieved.Email) + assert.Equal(t, "Sensitive Name", retrieved.Name) + }) +} + +func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + // Deleting a non-existent invite should not return an error + err := store.DeleteUserInvite(ctx, "non-existent-invite-id") + require.NoError(t, err) + }) +} + +func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + email := "shared-email@example.com" + + // Create invite in first account + invite1 := &types.UserInviteRecord{ + ID: "invite-account1", + AccountID: "account-1", + Email: email, + Name: "User Account 1", + Role: "user", + AutoGroups: []string{}, + HashedToken: "hashed-token-account1", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-1", + } + + // Create invite in second account with same email + invite2 := &types.UserInviteRecord{ + ID: "invite-account2", + AccountID: "account-2", + Email: email, + Name: "User Account 2", + Role: "admin", + AutoGroups: []string{"group-1"}, + HashedToken: "hashed-token-account2", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-2", + } + + err := store.SaveUserInvite(ctx, invite1) + require.NoError(t, err) + + err = store.SaveUserInvite(ctx, invite2) + require.NoError(t, err) + + // Verify each account gets the correct invite by email + retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email) + require.NoError(t, err) + assert.Equal(t, "invite-account1", retrieved1.ID) + assert.Equal(t, "User Account 1", retrieved1.Name) + + retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email) + require.NoError(t, err) + assert.Equal(t, "invite-account2", retrieved2.ID) + assert.Equal(t, "User Account 2", retrieved2.Name) + }) +} + +func TestSqlStore_UserInvite_LockingStrength(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + invite := &types.UserInviteRecord{ + ID: "invite-locking", + AccountID: "account-locking", + Email: "locking@example.com", + Name: "Locking Test User", + Role: "user", + AutoGroups: []string{}, + HashedToken: "hashed-token-locking", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + // Test with different locking strengths + lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate} + + for _, strength := range lockStrengths { + retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + + retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + + retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email) + require.NoError(t, err) + assert.Equal(t, invite.ID, retrieved.ID) + + invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID) + require.NoError(t, err) + assert.Len(t, invites, 1) + } + }) +} + +func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + // Test with nil AutoGroups + invite := &types.UserInviteRecord{ + ID: "invite-nil-autogroups", + AccountID: "account-autogroups", + Email: "nilgroups@example.com", + Name: "Nil Groups User", + Role: "user", + AutoGroups: nil, + HashedToken: "hashed-token-nil", + ExpiresAt: time.Now().Add(72 * time.Hour), + CreatedAt: time.Now(), + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + // Should return empty slice or nil, both are acceptable + assert.Empty(t, retrieved.AutoGroups) + }) +} + +func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Skip("store is nil") + } + ctx := context.Background() + + now := time.Now().UTC().Truncate(time.Millisecond) + expiresAt := now.Add(72 * time.Hour) + + invite := &types.UserInviteRecord{ + ID: "invite-timestamp", + AccountID: "account-timestamp", + Email: "timestamp@example.com", + Name: "Timestamp User", + Role: "user", + AutoGroups: []string{}, + HashedToken: "hashed-token-timestamp", + ExpiresAt: expiresAt, + CreatedAt: now, + CreatedBy: "admin-user", + } + + err := store.SaveUserInvite(ctx, invite) + require.NoError(t, err) + + retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID) + require.NoError(t, err) + + // Verify timestamps are preserved (within reasonable precision) + assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second) + assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 02c746592..be0d29768 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -92,6 +92,13 @@ type Store interface { DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error + GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) + GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) + GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) + GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) + DeleteUserInvite(ctx context.Context, inviteID string) error + GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) diff --git a/management/server/types/user_invite.go b/management/server/types/user_invite.go new file mode 100644 index 000000000..1544b0ff3 --- /dev/null +++ b/management/server/types/user_invite.go @@ -0,0 +1,201 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "fmt" + "hash/crc32" + "strings" + "time" + + b "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/rs/xid" + + "github.com/netbirdio/netbird/base62" + "github.com/netbirdio/netbird/util/crypt" +) + +const ( + // InviteTokenPrefix is the prefix for invite tokens + InviteTokenPrefix = "nbi_" + // InviteTokenSecretLength is the length of the random secret part + InviteTokenSecretLength = 30 + // InviteTokenChecksumLength is the length of the encoded checksum + InviteTokenChecksumLength = 6 + // InviteTokenLength is the total length of the token (4 + 30 + 6 = 40) + InviteTokenLength = 40 + // DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours) + DefaultInviteExpirationSeconds = 259200 + // MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour) + MinInviteExpirationSeconds = 3600 +) + +// UserInviteRecord represents an invitation for a user to set up their account (database model) +type UserInviteRecord struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index;not null"` + Email string `gorm:"index;not null"` + Name string `gorm:"not null"` + Role string `gorm:"not null"` + AutoGroups []string `gorm:"serializer:json"` + HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded) + ExpiresAt time.Time `gorm:"not null"` + CreatedAt time.Time `gorm:"not null"` + CreatedBy string `gorm:"not null"` +} + +// TableName returns the table name for GORM +func (UserInviteRecord) TableName() string { + return "user_invites" +} + +// GenerateInviteToken creates a new invite token with the format: nbi_ +// Returns the hashed token (for storage) and the plain token (to give to the user) +func GenerateInviteToken() (hashedToken string, plainToken string, err error) { + secret, err := b.Random(InviteTokenSecretLength) + if err != nil { + return "", "", fmt.Errorf("failed to generate random secret: %w", err) + } + + checksum := crc32.ChecksumIEEE([]byte(secret)) + encodedChecksum := base62.Encode(checksum) + // Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode) + paddedChecksum := encodedChecksum + if len(paddedChecksum) < InviteTokenChecksumLength { + paddedChecksum = strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum + } + + plainToken = InviteTokenPrefix + secret + paddedChecksum + hash := sha256.Sum256([]byte(plainToken)) + hashedToken = b64.StdEncoding.EncodeToString(hash[:]) + + return hashedToken, plainToken, nil +} + +// HashInviteToken creates a SHA-256 hash of the token (base64 encoded) +func HashInviteToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return b64.StdEncoding.EncodeToString(hash[:]) +} + +// ValidateInviteToken validates the token format and checksum. +// Returns an error if the token is invalid. +func ValidateInviteToken(token string) error { + if len(token) != InviteTokenLength { + return fmt.Errorf("invalid token length") + } + + prefix := token[:len(InviteTokenPrefix)] + if prefix != InviteTokenPrefix { + return fmt.Errorf("invalid token prefix") + } + + secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength] + encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:] + + verificationChecksum, err := base62.Decode(encodedChecksum) + if err != nil { + return fmt.Errorf("checksum decoding failed: %w", err) + } + + secretChecksum := crc32.ChecksumIEEE([]byte(secret)) + if secretChecksum != verificationChecksum { + return fmt.Errorf("checksum does not match") + } + + return nil +} + +// IsExpired checks if the invite has expired +func (i *UserInviteRecord) IsExpired() bool { + return time.Now().After(i.ExpiresAt) +} + +// UserInvite contains the result of creating or regenerating an invite +type UserInvite struct { + UserInfo *UserInfo + InviteToken string + InviteExpiresAt time.Time + InviteCreatedAt time.Time +} + +// UserInviteInfo contains public information about an invite (for unauthenticated endpoint) +type UserInviteInfo struct { + Email string `json:"email"` + Name string `json:"name"` + ExpiresAt time.Time `json:"expires_at"` + Valid bool `json:"valid"` + InvitedBy string `json:"invited_by"` +} + +// NewInviteID generates a new invite ID using xid +func NewInviteID() string { + return xid.New().String() +} + +// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place. +func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + var err error + if i.Email != "" { + i.Email, err = enc.Encrypt(i.Email) + if err != nil { + return fmt.Errorf("encrypt email: %w", err) + } + } + + if i.Name != "" { + i.Name, err = enc.Encrypt(i.Name) + if err != nil { + return fmt.Errorf("encrypt name: %w", err) + } + } + + return nil +} + +// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place. +func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { + if enc == nil { + return nil + } + + var err error + if i.Email != "" { + i.Email, err = enc.Decrypt(i.Email) + if err != nil { + return fmt.Errorf("decrypt email: %w", err) + } + } + + if i.Name != "" { + i.Name, err = enc.Decrypt(i.Name) + if err != nil { + return fmt.Errorf("decrypt name: %w", err) + } + } + + return nil +} + +// Copy creates a deep copy of the UserInviteRecord +func (i *UserInviteRecord) Copy() *UserInviteRecord { + autoGroups := make([]string, len(i.AutoGroups)) + copy(autoGroups, i.AutoGroups) + + return &UserInviteRecord{ + ID: i.ID, + AccountID: i.AccountID, + Email: i.Email, + Name: i.Name, + Role: i.Role, + AutoGroups: autoGroups, + HashedToken: i.HashedToken, + ExpiresAt: i.ExpiresAt, + CreatedAt: i.CreatedAt, + CreatedBy: i.CreatedBy, + } +} diff --git a/management/server/types/user_invite_test.go b/management/server/types/user_invite_test.go new file mode 100644 index 000000000..09dae3800 --- /dev/null +++ b/management/server/types/user_invite_test.go @@ -0,0 +1,355 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "hash/crc32" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/base62" + "github.com/netbirdio/netbird/util/crypt" +) + +func TestUserInviteRecord_TableName(t *testing.T) { + invite := UserInviteRecord{} + assert.Equal(t, "user_invites", invite.TableName()) +} + +func TestGenerateInviteToken_Success(t *testing.T) { + hashedToken, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + assert.NotEmpty(t, hashedToken) + assert.NotEmpty(t, plainToken) +} + +func TestGenerateInviteToken_Length(t *testing.T) { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + assert.Len(t, plainToken, InviteTokenLength) +} + +func TestGenerateInviteToken_Prefix(t *testing.T) { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix)) +} + +func TestGenerateInviteToken_Hashing(t *testing.T) { + hashedToken, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + expectedHash := sha256.Sum256([]byte(plainToken)) + expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:]) + assert.Equal(t, expectedHashedToken, hashedToken) +} + +func TestGenerateInviteToken_Checksum(t *testing.T) { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + // Extract parts + secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength] + checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:] + + // Verify checksum + expectedChecksum := crc32.ChecksumIEEE([]byte(secret)) + actualChecksum, err := base62.Decode(checksumStr) + require.NoError(t, err) + assert.Equal(t, expectedChecksum, actualChecksum) +} + +func TestGenerateInviteToken_Uniqueness(t *testing.T) { + tokens := make(map[string]bool) + for i := 0; i < 100; i++ { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + assert.False(t, tokens[plainToken], "Token should be unique") + tokens[plainToken] = true + } +} + +func TestHashInviteToken(t *testing.T) { + token := "nbi_testtoken123456789012345678901234" + hashedToken := HashInviteToken(token) + + expectedHash := sha256.Sum256([]byte(token)) + expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:]) + assert.Equal(t, expectedHashedToken, hashedToken) +} + +func TestHashInviteToken_Consistency(t *testing.T) { + token := "nbi_testtoken123456789012345678901234" + hash1 := HashInviteToken(token) + hash2 := HashInviteToken(token) + assert.Equal(t, hash1, hash2) +} + +func TestHashInviteToken_DifferentTokens(t *testing.T) { + token1 := "nbi_testtoken123456789012345678901234" + token2 := "nbi_testtoken123456789012345678901235" + hash1 := HashInviteToken(token1) + hash2 := HashInviteToken(token2) + assert.NotEqual(t, hash1, hash2) +} + +func TestValidateInviteToken_Success(t *testing.T) { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + err = ValidateInviteToken(plainToken) + assert.NoError(t, err) +} + +func TestValidateInviteToken_InvalidLength(t *testing.T) { + testCases := []struct { + name string + token string + }{ + {"empty", ""}, + {"too short", "nbi_abc"}, + {"too long", "nbi_" + strings.Repeat("a", 50)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateInviteToken(tc.token) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid token length") + }) + } +} + +func TestValidateInviteToken_InvalidPrefix(t *testing.T) { + // Create a token with wrong prefix but correct length + token := "xyz_" + strings.Repeat("a", 30) + "000000" + err := ValidateInviteToken(token) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid token prefix") +} + +func TestValidateInviteToken_InvalidChecksum(t *testing.T) { + // Create a token with correct format but invalid checksum + token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ" + err := ValidateInviteToken(token) + require.Error(t, err) + assert.Contains(t, err.Error(), "checksum") +} + +func TestValidateInviteToken_ModifiedToken(t *testing.T) { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + // Modify one character in the secret part + modifiedToken := plainToken[:5] + "X" + plainToken[6:] + err = ValidateInviteToken(modifiedToken) + require.Error(t, err) +} + +func TestUserInviteRecord_IsExpired(t *testing.T) { + t.Run("not expired", func(t *testing.T) { + invite := &UserInviteRecord{ + ExpiresAt: time.Now().Add(time.Hour), + } + assert.False(t, invite.IsExpired()) + }) + + t.Run("expired", func(t *testing.T) { + invite := &UserInviteRecord{ + ExpiresAt: time.Now().Add(-time.Hour), + } + assert.True(t, invite.IsExpired()) + }) + + t.Run("just expired", func(t *testing.T) { + invite := &UserInviteRecord{ + ExpiresAt: time.Now().Add(-time.Second), + } + assert.True(t, invite.IsExpired()) + }) +} + +func TestNewInviteID(t *testing.T) { + id := NewInviteID() + assert.NotEmpty(t, id) + assert.Len(t, id, 20) // xid generates 20 character IDs +} + +func TestNewInviteID_Uniqueness(t *testing.T) { + ids := make(map[string]bool) + for i := 0; i < 100; i++ { + id := NewInviteID() + assert.False(t, ids[id], "ID should be unique") + ids[id] = true + } +} + +func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) { + key, err := crypt.GenerateKey() + require.NoError(t, err) + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + + t.Run("encrypt and decrypt", func(t *testing.T) { + invite := &UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "test@example.com", + Name: "Test User", + Role: "user", + } + + // Encrypt + err := invite.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Verify encrypted values are different from original + assert.NotEqual(t, "test@example.com", invite.Email) + assert.NotEqual(t, "Test User", invite.Name) + + // Decrypt + err = invite.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Verify decrypted values match original + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + }) + + t.Run("encrypt empty fields", func(t *testing.T) { + invite := &UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "", + Name: "", + Role: "user", + } + + err := invite.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + assert.Equal(t, "", invite.Email) + assert.Equal(t, "", invite.Name) + + err = invite.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + assert.Equal(t, "", invite.Email) + assert.Equal(t, "", invite.Name) + }) + + t.Run("nil encryptor", func(t *testing.T) { + invite := &UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "test@example.com", + Name: "Test User", + Role: "user", + } + + err := invite.EncryptSensitiveData(nil) + require.NoError(t, err) + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + + err = invite.DecryptSensitiveData(nil) + require.NoError(t, err) + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + }) +} + +func TestUserInviteRecord_Copy(t *testing.T) { + now := time.Now() + expiresAt := now.Add(72 * time.Hour) + + original := &UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + Email: "test@example.com", + Name: "Test User", + Role: "user", + AutoGroups: []string{"group1", "group2"}, + HashedToken: "hashed-token", + ExpiresAt: expiresAt, + CreatedAt: now, + CreatedBy: "creator-id", + } + + copied := original.Copy() + + // Verify all fields are copied + assert.Equal(t, original.ID, copied.ID) + assert.Equal(t, original.AccountID, copied.AccountID) + assert.Equal(t, original.Email, copied.Email) + assert.Equal(t, original.Name, copied.Name) + assert.Equal(t, original.Role, copied.Role) + assert.Equal(t, original.AutoGroups, copied.AutoGroups) + assert.Equal(t, original.HashedToken, copied.HashedToken) + assert.Equal(t, original.ExpiresAt, copied.ExpiresAt) + assert.Equal(t, original.CreatedAt, copied.CreatedAt) + assert.Equal(t, original.CreatedBy, copied.CreatedBy) + + // Verify deep copy of AutoGroups (modifying copy doesn't affect original) + copied.AutoGroups[0] = "modified" + assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0]) + assert.Equal(t, "group1", original.AutoGroups[0]) +} + +func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) { + original := &UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + AutoGroups: []string{}, + } + + copied := original.Copy() + assert.NotNil(t, copied.AutoGroups) + assert.Len(t, copied.AutoGroups, 0) +} + +func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) { + original := &UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + AutoGroups: nil, + } + + copied := original.Copy() + assert.NotNil(t, copied.AutoGroups) + assert.Len(t, copied.AutoGroups, 0) +} + +func TestInviteTokenConstants(t *testing.T) { + // Verify constants are consistent + expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength + assert.Equal(t, InviteTokenLength, expectedLength) + assert.Equal(t, 4, len(InviteTokenPrefix)) + assert.Equal(t, 30, InviteTokenSecretLength) + assert.Equal(t, 6, InviteTokenChecksumLength) + assert.Equal(t, 40, InviteTokenLength) + assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours + assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour +} + +func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) { + // Generate multiple tokens and ensure they all validate + for i := 0; i < 50; i++ { + _, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + err = ValidateInviteToken(plainToken) + assert.NoError(t, err, "Generated token should always be valid") + } +} + +func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) { + hashedToken, plainToken, err := GenerateInviteToken() + require.NoError(t, err) + + // HashInviteToken should produce the same hash as GenerateInviteToken + rehashedToken := HashInviteToken(plainToken) + assert.Equal(t, hashedToken, rehashedToken) +} diff --git a/management/server/user.go b/management/server/user.go index 0a090d681..51da7a633 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "time" + "unicode" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/auth" @@ -1453,3 +1454,368 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init return nil } + +// CreateUserInvite creates an invite link for a new user in the embedded IdP. +// The user is NOT created until the invite is accepted. +func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + if !IsEmbeddedIdp(am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + if err := validateUserInvite(invite); err != nil { + return nil, err + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + // Check if user already exists in NetBird DB + existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, err + } + for _, user := range existingUsers { + if strings.EqualFold(user.Email, invite.Email) { + return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists") + } + } + + // Check if invite already exists for this email + existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email) + if err != nil { + if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { + return nil, fmt.Errorf("failed to check existing invites: %w", err) + } + } + if existingInvite != nil { + return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email") + } + + // Calculate expiration time + if expiresIn <= 0 { + expiresIn = types.DefaultInviteExpirationSeconds + } + + if expiresIn < types.MinInviteExpirationSeconds { + return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour") + } + expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second) + + // Generate invite token + inviteID := types.NewInviteID() + hashedToken, plainToken, err := types.GenerateInviteToken() + if err != nil { + return nil, fmt.Errorf("failed to generate invite token: %w", err) + } + + // Create the invite record (no user created yet) + userInvite := &types.UserInviteRecord{ + ID: inviteID, + AccountID: accountID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: invite.AutoGroups, + HashedToken: hashedToken, + ExpiresAt: expiresAt, + CreatedAt: time.Now().UTC(), + CreatedBy: initiatorUserID, + } + + if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email}) + + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: inviteID, + Email: invite.Email, + Name: invite.Name, + Role: invite.Role, + AutoGroups: invite.AutoGroups, + Status: string(types.UserStatusInvited), + Issued: types.UserIssuedAPI, + }, + InviteToken: plainToken, + InviteExpiresAt: expiresAt, + }, nil +} + +// GetUserInviteInfo retrieves invite information from a token (public endpoint). +func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) { + if err := types.ValidateInviteToken(token); err != nil { + return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err) + } + + hashedToken := types.HashInviteToken(token) + invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken) + if err != nil { + return nil, err + } + + // Get the inviter's name + invitedBy := "" + if invite.CreatedBy != "" { + inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy) + if err == nil && inviter != nil { + invitedBy = inviter.Name + } + } + + return &types.UserInviteInfo{ + Email: invite.Email, + Name: invite.Name, + ExpiresAt: invite.ExpiresAt, + Valid: !invite.IsExpired(), + InvitedBy: invitedBy, + }, nil +} + +// ListUserInvites returns all invites for an account. +func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) { + if !IsEmbeddedIdp(am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, err + } + + invites := make([]*types.UserInvite, 0, len(records)) + for _, record := range records { + invites = append(invites, &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: record.ID, + Email: record.Email, + Name: record.Name, + Role: record.Role, + AutoGroups: record.AutoGroups, + }, + InviteExpiresAt: record.ExpiresAt, + InviteCreatedAt: record.CreatedAt, + }) + } + + return invites, nil +} + +// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB. +func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error { + if !IsEmbeddedIdp(am.idpManager) { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + if password == "" { + return status.Errorf(status.InvalidArgument, "password is required") + } + + if err := validatePassword(password); err != nil { + return status.Errorf(status.InvalidArgument, "invalid password: %v", err) + } + + if err := types.ValidateInviteToken(token); err != nil { + return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err) + } + + hashedToken := types.HashInviteToken(token) + invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken) + if err != nil { + return err + } + + if invite.IsExpired() { + return status.Errorf(status.InvalidArgument, "invite has expired") + } + + // Create user in Dex with the provided password + embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return status.Errorf(status.Internal, "failed to get embedded IdP manager") + } + + idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name) + if err != nil { + return fmt.Errorf("failed to create user in IdP: %w", err) + } + + // Create user in NetBird DB + newUser := &types.User{ + Id: idpUser.ID, + AccountID: invite.AccountID, + Role: types.StrRoleToUserRole(invite.Role), + AutoGroups: invite.AutoGroups, + Issued: types.UserIssuedAPI, + CreatedAt: time.Now().UTC(), + Email: invite.Email, + Name: invite.Name, + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := transaction.SaveUser(ctx, newUser); err != nil { + return fmt.Errorf("failed to save user: %w", err) + } + if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil { + return fmt.Errorf("failed to delete invite: %w", err) + } + return nil + }) + if err != nil { + // Best-effort rollback: delete the IdP user to avoid orphaned records + if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil { + log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID) + } + return err + } + + am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email}) + + return nil +} + +// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one. +func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) { + if !IsEmbeddedIdp(am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + // Get existing invite + existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID) + if err != nil { + return nil, err + } + + // Calculate expiration time + if expiresIn <= 0 { + expiresIn = types.DefaultInviteExpirationSeconds + } + if expiresIn < types.MinInviteExpirationSeconds { + return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour") + } + expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second) + + // Generate new invite token + hashedToken, plainToken, err := types.GenerateInviteToken() + if err != nil { + return nil, fmt.Errorf("failed to generate invite token: %w", err) + } + + // Update existing invite with new token and expiration + existingInvite.HashedToken = hashedToken + existingInvite.ExpiresAt = expiresAt + existingInvite.CreatedBy = initiatorUserID + + err = am.Store.SaveUserInvite(ctx, existingInvite) + if err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email}) + + return &types.UserInvite{ + UserInfo: &types.UserInfo{ + ID: existingInvite.ID, + Email: existingInvite.Email, + Name: existingInvite.Name, + Role: existingInvite.Role, + AutoGroups: existingInvite.AutoGroups, + Status: string(types.UserStatusInvited), + Issued: types.UserIssuedAPI, + }, + InviteToken: plainToken, + InviteExpiresAt: expiresAt, + }, nil +} + +// DeleteUserInvite deletes an existing invite by ID. +func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error { + if !IsEmbeddedIdp(am.idpManager) { + return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID) + if err != nil { + return err + } + + if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil { + return err + } + + am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email}) + + return nil +} + +const minPasswordLength = 8 + +// validatePassword checks password strength requirements: +// - Minimum 8 characters +// - At least 1 digit +// - At least 1 uppercase letter +// - At least 1 special character +func validatePassword(password string) error { + if len(password) < minPasswordLength { + return errors.New("password must be at least 8 characters long") + } + + var hasDigit, hasUpper, hasSpecial bool + for _, c := range password { + switch { + case unicode.IsDigit(c): + hasDigit = true + case unicode.IsUpper(c): + hasUpper = true + case !unicode.IsLetter(c) && !unicode.IsDigit(c): + hasSpecial = true + } + } + + var missing []string + if !hasDigit { + missing = append(missing, "one digit") + } + if !hasUpper { + missing = append(missing, "one uppercase letter") + } + if !hasSpecial { + missing = append(missing, "one special character") + } + + if len(missing) > 0 { + return errors.New("password must contain at least " + strings.Join(missing, ", ")) + } + + return nil +} diff --git a/management/server/user_invite_test.go b/management/server/user_invite_test.go new file mode 100644 index 000000000..6256ed44a --- /dev/null +++ b/management/server/user_invite_test.go @@ -0,0 +1,1010 @@ +package server + +import ( + "context" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util/crypt" +) + +const ( + testAccountID = "testAccountID" + testAdminUserID = "testAdminUserID" + testRegularUserID = "testRegularUserID" +) + +// setupInviteTestManagerWithEmbeddedIdP creates a test manager with a real embedded IdP +// and store encryption enabled. This is required for tests that need to pass the IsEmbeddedIdp check. +func setupInviteTestManagerWithEmbeddedIdP(t *testing.T) (*DefaultAccountManager, func()) { + t.Helper() + ctx := context.Background() + + tmpDir := t.TempDir() + dexDataDir := tmpDir + "/dex" + require.NoError(t, os.MkdirAll(dexDataDir, 0700)) + + // Create test store + s, cleanup, err := store.NewTestStoreFromSQL(ctx, "", tmpDir) + require.NoError(t, err, "Error when creating store") + + // Enable encryption + key, err := crypt.GenerateKey() + require.NoError(t, err) + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + s.SetFieldEncrypt(fieldEncrypt) + + // Create embedded IDP config + embeddedIdPConfig := &idp.EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: idp.EmbeddedStorageConfig{ + Type: "sqlite3", + Config: idp.EmbeddedStorageTypeConfig{ + File: dexDataDir + "/dex.db", + }, + }, + } + + // Create embedded IDP manager + embeddedIdp, err := idp.NewEmbeddedIdPManager(ctx, embeddedIdPConfig, nil) + require.NoError(t, err) + + account := newAccountWithId(ctx, testAccountID, testAdminUserID, "", "admin@test.com", "Admin User", false) + account.Users[testRegularUserID] = &types.User{ + Id: testRegularUserID, + AccountID: testAccountID, + Role: types.UserRoleUser, + Email: "regular@test.com", + Name: "Regular User", + } + + err = s.SaveAccount(ctx, account) + require.NoError(t, err, "Error when saving account") + + permissionsManager := permissions.NewManager(s) + + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + idpManager: embeddedIdp, + } + + cleanupFunc := func() { + _ = embeddedIdp.Stop(ctx) + cleanup() + } + + return &am, cleanupFunc +} + +func TestCreateUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, "newuser@test.com", result.UserInfo.Email) + assert.Equal(t, "New User", result.UserInfo.Name) + assert.Equal(t, "user", result.UserInfo.Role) + assert.Equal(t, string(types.UserStatusInvited), result.UserInfo.Status) + assert.NotEmpty(t, result.InviteToken) + assert.True(t, result.InviteExpiresAt.After(time.Now())) + + // Verify invite is stored in DB + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + assert.Len(t, invites, 1) + assert.Equal(t, "newuser@test.com", invites[0].Email) +} + +func TestCreateUserInvite_DuplicateEmail(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + // Create first invite + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Try to create duplicate invite + _, err = am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.AlreadyExists, sErr.Type()) +} + +func TestCreateUserInvite_ExistingUserEmail(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Try to invite with an email that already exists as a user + invite := &types.UserInfo{ + Email: "regular@test.com", // Already exists as a user + Name: "Duplicate User", + Role: "user", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.UserAlreadyExists, sErr.Type()) +} + +func TestCreateUserInvite_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + // Regular user should not be able to create invites + _, err := am.CreateUserInvite(context.Background(), testAccountID, testRegularUserID, invite, 0) + require.Error(t, err) +} + +func TestCreateUserInvite_InvalidEmail(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestCreateUserInvite_InvalidName(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "", + Role: "user", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestCreateUserInvite_OwnerRole(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newowner@test.com", + Name: "New Owner", + Role: "owner", + AutoGroups: []string{}, + } + + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestCreateUserInvite_ExpirationTooShort(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + // Try to create with expiration less than 1 hour + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 1800) // 30 minutes + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) + assert.Contains(t, err.Error(), "at least 1 hour") +} + +func TestCreateUserInvite_CustomExpiration(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + expiresIn := 7200 // 2 hours + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, expiresIn) + require.NoError(t, err) + + // Verify expiration is approximately 2 hours from now + expectedExpiration := time.Now().Add(time.Duration(expiresIn) * time.Second) + assert.WithinDuration(t, expectedExpiration, result.InviteExpiresAt, time.Minute) +} + +func TestCreateUserInvite_WithAutoGroups(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{"group1", "group2"}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + assert.Equal(t, []string{"group1", "group2"}, result.UserInfo.AutoGroups) + + // Verify invite in DB has auto groups + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + require.Len(t, invites, 1) + assert.Equal(t, []string{"group1", "group2"}, invites[0].AutoGroups) +} + +func TestGetUserInviteInfo_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Get the invite info using the token + info, err := am.GetUserInviteInfo(context.Background(), result.InviteToken) + require.NoError(t, err) + require.NotNil(t, info) + + assert.Equal(t, "newuser@test.com", info.Email) + assert.Equal(t, "New User", info.Name) + assert.True(t, info.Valid) + assert.Equal(t, "Admin User", info.InvitedBy) +} + +func TestGetUserInviteInfo_InvalidToken(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + _, err := am.GetUserInviteInfo(context.Background(), "invalid_token") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestGetUserInviteInfo_TokenNotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Generate a valid token format that doesn't exist in DB + _, validToken, err := types.GenerateInviteToken() + require.NoError(t, err) + + _, err = am.GetUserInviteInfo(context.Background(), validToken) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestGetUserInviteInfo_ExpiredInvite(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite with valid expiration + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Manually set the invite to expired by updating the store directly + inviteRecord, err := am.Store.GetUserInviteByID(context.Background(), store.LockingStrengthUpdate, testAccountID, result.UserInfo.ID) + require.NoError(t, err) + inviteRecord.ExpiresAt = time.Now().Add(-time.Hour) // Set to 1 hour ago + err = am.Store.SaveUserInvite(context.Background(), inviteRecord) + require.NoError(t, err) + + // Get the invite info - should still return info but Valid should be false + info, err := am.GetUserInviteInfo(context.Background(), result.InviteToken) + require.NoError(t, err) + assert.False(t, info.Valid) +} + +func TestListUserInvites_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create multiple invites + for i, email := range []string{"user1@test.com", "user2@test.com", "user3@test.com"} { + invite := &types.UserInfo{ + Email: email, + Name: "User " + string(rune('1'+i)), + Role: "user", + AutoGroups: []string{}, + } + _, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + } + + // List invites + invites, err := am.ListUserInvites(context.Background(), testAccountID, testAdminUserID) + require.NoError(t, err) + assert.Len(t, invites, 3) +} + +func TestListUserInvites_Empty(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + invites, err := am.ListUserInvites(context.Background(), testAccountID, testAdminUserID) + require.NoError(t, err) + assert.Len(t, invites, 0) +} + +func TestListUserInvites_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + _, err := am.ListUserInvites(context.Background(), testAccountID, testRegularUserID) + require.Error(t, err) +} + +func TestRegenerateUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + originalResult, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Regenerate the invite + newResult, err := am.RegenerateUserInvite(context.Background(), testAccountID, testAdminUserID, originalResult.UserInfo.ID, 0) + require.NoError(t, err) + require.NotNil(t, newResult) + + // Verify invite ID remains the same (stable ID for clients) + assert.Equal(t, originalResult.UserInfo.ID, newResult.UserInfo.ID) + + // Verify new token is different + assert.NotEqual(t, originalResult.InviteToken, newResult.InviteToken) + assert.Equal(t, "newuser@test.com", newResult.UserInfo.Email) + assert.Equal(t, "New User", newResult.UserInfo.Name) + + // Verify old token no longer works + _, err = am.GetUserInviteInfo(context.Background(), originalResult.InviteToken) + require.Error(t, err) + + // Verify new token works + info, err := am.GetUserInviteInfo(context.Background(), newResult.InviteToken) + require.NoError(t, err) + assert.Equal(t, "newuser@test.com", info.Email) +} + +func TestRegenerateUserInvite_NotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + _, err := am.RegenerateUserInvite(context.Background(), testAccountID, testAdminUserID, "nonexistent-id", 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestRegenerateUserInvite_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Regular user should not be able to regenerate + _, err = am.RegenerateUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID, 0) + require.Error(t, err) +} + +func TestDeleteUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Delete the invite + err = am.DeleteUserInvite(context.Background(), testAccountID, testAdminUserID, result.UserInfo.ID) + require.NoError(t, err) + + // Verify invite is deleted + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + assert.Len(t, invites, 0) + + // Verify token no longer works + _, err = am.GetUserInviteInfo(context.Background(), result.InviteToken) + require.Error(t, err) +} + +func TestDeleteUserInvite_NotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + err := am.DeleteUserInvite(context.Background(), testAccountID, testAdminUserID, "nonexistent-id") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestDeleteUserInvite_PermissionDenied(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite first + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Regular user should not be able to delete + err = am.DeleteUserInvite(context.Background(), testAccountID, testRegularUserID, result.UserInfo.ID) + require.Error(t, err) +} + +func TestDeleteUserInvite_WrongAccount(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Create another account + anotherAccountID := "anotherAccountID" + anotherAdminID := "anotherAdminID" + anotherAccount := newAccountWithId(context.Background(), anotherAccountID, anotherAdminID, "", "otheradmin@test.com", "Other Admin", false) + err = am.Store.SaveAccount(context.Background(), anotherAccount) + require.NoError(t, err) + + // Try to delete from wrong account + err = am.DeleteUserInvite(context.Background(), anotherAccountID, anotherAdminID, result.UserInfo.ID) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestAcceptUserInvite_Success(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Accept the invite with a valid password + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "Password1!") + require.NoError(t, err) + + // Verify user is created in DB + users, err := am.Store.GetAccountUsers(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + + var foundUser *types.User + for _, u := range users { + if u.Email == "newuser@test.com" { + foundUser = u + break + } + } + require.NotNil(t, foundUser, "User should be created in DB") + assert.Equal(t, "New User", foundUser.Name) + assert.Equal(t, types.UserRoleUser, foundUser.Role) + + // Verify invite is deleted + invites, err := am.Store.GetAccountUserInvites(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + assert.Len(t, invites, 0) +} + +func TestAcceptUserInvite_InvalidToken(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + err := am.AcceptUserInvite(context.Background(), "invalid_token", "Password1!") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) +} + +func TestAcceptUserInvite_TokenNotFound(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Generate a valid token format that doesn't exist in DB + _, validToken, err := types.GenerateInviteToken() + require.NoError(t, err) + + err = am.AcceptUserInvite(context.Background(), validToken, "Password1!") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.NotFound, sErr.Type()) +} + +func TestAcceptUserInvite_ExpiredToken(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite with valid expiration + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Manually set the invite to expired by updating the store directly + inviteRecord, err := am.Store.GetUserInviteByID(context.Background(), store.LockingStrengthUpdate, testAccountID, result.UserInfo.ID) + require.NoError(t, err) + inviteRecord.ExpiresAt = time.Now().Add(-time.Hour) // Set to 1 hour ago + err = am.Store.SaveUserInvite(context.Background(), inviteRecord) + require.NoError(t, err) + + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "Password1!") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) + assert.Contains(t, err.Error(), "expired") +} + +func TestAcceptUserInvite_EmptyPassword(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "") + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.InvalidArgument, sErr.Type()) + assert.Contains(t, err.Error(), "password is required") +} + +func TestAcceptUserInvite_WeakPassword(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + testCases := []struct { + name string + password string + expectedMsg string + }{ + {"too short", "Pass1!", "at least 8 characters"}, + {"no digit", "Password!", "one digit"}, + {"no uppercase", "password1!", "one uppercase"}, + {"no special", "Password1", "one special character"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := am.AcceptUserInvite(context.Background(), result.InviteToken, tc.password) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedMsg) + }) + } +} + +func TestValidatePassword(t *testing.T) { + testCases := []struct { + name string + password string + expectError bool + errorMsg string + }{ + {"valid password", "Password1!", false, ""}, + {"valid complex password", "MyP@ssw0rd#2024", false, ""}, + {"too short", "Pass1!", true, "at least 8 characters"}, + {"no digit", "Password!", true, "one digit"}, + {"no uppercase", "password1!", true, "one uppercase"}, + {"no special", "Password1", true, "one special character"}, + {"only lowercase", "password", true, "one digit"}, + {"no uppercase no special", "password1", true, "one uppercase"}, + {"all lowercase short", "pass", true, "at least 8 characters"}, + {"empty", "", true, "at least 8 characters"}, + {"spaces count as special", "Pass word1", false, ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validatePassword(tc.password) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestInviteToken_GenerateAndValidate(t *testing.T) { + hashedToken, plainToken, err := types.GenerateInviteToken() + require.NoError(t, err) + require.NotEmpty(t, hashedToken) + require.NotEmpty(t, plainToken) + + // Validate token format + assert.Len(t, plainToken, types.InviteTokenLength) + assert.True(t, len(plainToken) > len(types.InviteTokenPrefix)) + assert.Equal(t, types.InviteTokenPrefix, plainToken[:len(types.InviteTokenPrefix)]) + + // Validate checksum + err = types.ValidateInviteToken(plainToken) + require.NoError(t, err) + + // Verify hashing is consistent + hashedAgain := types.HashInviteToken(plainToken) + assert.Equal(t, hashedToken, hashedAgain) +} + +func TestInviteToken_ValidateInvalid(t *testing.T) { + testCases := []struct { + name string + token string + }{ + {"empty", ""}, + {"too short", "nbi_abc"}, + {"wrong prefix", "xyz_123456789012345678901234567890"}, + {"invalid checksum", "nbi_123456789012345678901234567890abcdef"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := types.ValidateInviteToken(tc.token) + require.Error(t, err) + }) + } +} + +func TestUserInviteRecord_IsExpired(t *testing.T) { + // Not expired + invite := &types.UserInviteRecord{ + ExpiresAt: time.Now().Add(time.Hour), + } + assert.False(t, invite.IsExpired()) + + // Expired + invite = &types.UserInviteRecord{ + ExpiresAt: time.Now().Add(-time.Hour), + } + assert.True(t, invite.IsExpired()) +} + +func TestUserInviteRecord_Copy(t *testing.T) { + original := &types.UserInviteRecord{ + ID: "invite-id", + AccountID: "account-id", + Email: "test@example.com", + Name: "Test User", + Role: "user", + AutoGroups: []string{"group1", "group2"}, + HashedToken: "hashed-token", + ExpiresAt: time.Now().Add(time.Hour), + CreatedAt: time.Now(), + CreatedBy: "creator-id", + } + + copied := original.Copy() + + assert.Equal(t, original.ID, copied.ID) + assert.Equal(t, original.AccountID, copied.AccountID) + assert.Equal(t, original.Email, copied.Email) + assert.Equal(t, original.Name, copied.Name) + assert.Equal(t, original.Role, copied.Role) + assert.Equal(t, original.AutoGroups, copied.AutoGroups) + assert.Equal(t, original.HashedToken, copied.HashedToken) + assert.Equal(t, original.ExpiresAt, copied.ExpiresAt) + assert.Equal(t, original.CreatedAt, copied.CreatedAt) + assert.Equal(t, original.CreatedBy, copied.CreatedBy) + + // Verify deep copy of AutoGroups + copied.AutoGroups[0] = "modified" + assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0]) +} + +func TestCreateUserInvite_NonEmbeddedIdP(t *testing.T) { + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + account := newAccountWithId(context.Background(), testAccountID, testAdminUserID, "", "admin@test.com", "Admin User", false) + err = s.SaveAccount(context.Background(), account) + require.NoError(t, err) + + permissionsManager := permissions.NewManager(s) + + // Use nil IDP manager (non-embedded) + am := DefaultAccountManager{ + Store: s, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + idpManager: nil, + } + + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "user", + AutoGroups: []string{}, + } + + _, err = am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.Error(t, err) + + sErr, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, status.PreconditionFailed, sErr.Type()) + assert.Contains(t, err.Error(), "embedded identity provider") +} + +func TestAcceptUserInvite_WithAutoGroups(t *testing.T) { + am, cleanup := setupInviteTestManagerWithEmbeddedIdP(t) + defer cleanup() + + // Create an invite with auto groups + invite := &types.UserInfo{ + Email: "newuser@test.com", + Name: "New User", + Role: "admin", + AutoGroups: []string{"group1", "group2"}, + } + + result, err := am.CreateUserInvite(context.Background(), testAccountID, testAdminUserID, invite, 0) + require.NoError(t, err) + + // Accept the invite + err = am.AcceptUserInvite(context.Background(), result.InviteToken, "Password1!") + require.NoError(t, err) + + // Verify user has the auto groups and role + users, err := am.Store.GetAccountUsers(context.Background(), store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + + var foundUser *types.User + for _, u := range users { + if u.Email == "newuser@test.com" { + foundUser = u + break + } + } + require.NotNil(t, foundUser) + assert.Equal(t, types.UserRoleAdmin, foundUser.Role) + assert.Equal(t, []string{"group1", "group2"}, foundUser.AutoGroups) +} + +func TestUserInvite_EncryptDecryptSensitiveData(t *testing.T) { + key, err := crypt.GenerateKey() + require.NoError(t, err) + fieldEncrypt, err := crypt.NewFieldEncrypt(key) + require.NoError(t, err) + + t.Run("encrypt and decrypt", func(t *testing.T) { + invite := &types.UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "test@example.com", + Name: "Test User", + Role: "user", + } + + // Encrypt + err := invite.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Verify encrypted values are different from original + assert.NotEqual(t, "test@example.com", invite.Email) + assert.NotEqual(t, "Test User", invite.Name) + + // Decrypt + err = invite.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Verify decrypted values match original + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + }) + + t.Run("encrypt empty fields", func(t *testing.T) { + invite := &types.UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "", + Name: "", + Role: "user", + } + + // Encrypt empty fields + err := invite.EncryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Empty strings should remain empty + assert.Equal(t, "", invite.Email) + assert.Equal(t, "", invite.Name) + + // Decrypt empty fields + err = invite.DecryptSensitiveData(fieldEncrypt) + require.NoError(t, err) + + // Should still be empty + assert.Equal(t, "", invite.Email) + assert.Equal(t, "", invite.Name) + }) + + t.Run("nil encryptor", func(t *testing.T) { + invite := &types.UserInviteRecord{ + ID: "test-invite", + AccountID: "test-account", + Email: "test@example.com", + Name: "Test User", + Role: "user", + } + + // Encrypt with nil encryptor should be no-op + err := invite.EncryptSensitiveData(nil) + require.NoError(t, err) + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + + // Decrypt with nil encryptor should be no-op + err = invite.DecryptSensitiveData(nil) + require.NoError(t, err) + assert.Equal(t, "test@example.com", invite.Email) + assert.Equal(t, "Test User", invite.Name) + }) +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index f1ff98b16..26d2387d1 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -488,6 +488,171 @@ components: - role - auto_groups - is_service_user + UserInviteCreateRequest: + type: object + description: Request to create a user invite link + properties: + email: + description: User's email address + type: string + example: user@example.com + name: + description: User's full name + type: string + example: John Doe + role: + description: User's NetBird account role + type: string + example: user + auto_groups: + description: Group IDs to auto-assign to peers registered by this user + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + expires_in: + description: Invite expiration time in seconds (default 72 hours) + type: integer + example: 259200 + required: + - email + - name + - role + - auto_groups + UserInvite: + type: object + description: A user invite + properties: + id: + description: Invite ID + type: string + example: d5p7eedra0h0lt6f59hg + email: + description: User's email address + type: string + example: user@example.com + name: + description: User's full name + type: string + example: John Doe + role: + description: User's NetBird account role + type: string + example: user + auto_groups: + description: Group IDs to auto-assign to peers registered by this user + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + expires_at: + description: Invite expiration time + type: string + format: date-time + example: "2024-01-25T10:00:00Z" + created_at: + description: Invite creation time + type: string + format: date-time + example: "2024-01-22T10:00:00Z" + expired: + description: Whether the invite has expired + type: boolean + example: false + invite_token: + description: The invite link to be shared with the user. Only returned when the invite is created or regenerated. + type: string + example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123 + required: + - id + - email + - name + - role + - auto_groups + - expires_at + - created_at + - expired + UserInviteInfo: + type: object + description: Public information about an invite + properties: + email: + description: User's email address + type: string + example: user@example.com + name: + description: User's full name + type: string + example: John Doe + expires_at: + description: Invite expiration time + type: string + format: date-time + example: "2024-01-25T10:00:00Z" + valid: + description: Whether the invite is still valid (not expired) + type: boolean + example: true + invited_by: + description: Name of the user who sent the invite + type: string + example: Admin User + required: + - email + - name + - expires_at + - valid + - invited_by + UserInviteAcceptRequest: + type: object + description: Request to accept an invite and set password + properties: + password: + description: >- + The password the user wants to set. Must be at least 8 characters long + and contain at least one uppercase letter, one digit, and one special + character (any character that is not a letter or digit, including spaces). + type: string + format: password + minLength: 8 + pattern: '^(?=.*[0-9])(?=.*[A-Z])(?=.*[^a-zA-Z0-9]).{8,}$' + example: SecurePass123! + required: + - password + UserInviteAcceptResponse: + type: object + description: Response after accepting an invite + properties: + success: + description: Whether the invite was accepted successfully + type: boolean + example: true + required: + - success + UserInviteRegenerateRequest: + type: object + description: Request to regenerate an invite link + properties: + expires_in: + description: Invite expiration time in seconds (default 72 hours) + type: integer + example: 259200 + UserInviteRegenerateResponse: + type: object + description: Response after regenerating an invite + properties: + invite_token: + description: The new invite token + type: string + example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123 + invite_expires_at: + description: New invite expiration time + type: string + format: date-time + example: "2024-01-28T10:00:00Z" + required: + - invite_token + - invite_expires_at PeerMinimum: type: object properties: @@ -2071,7 +2236,8 @@ components: "dns.zone.create", "dns.zone.update", "dns.zone.delete", "dns.zone.record.create", "dns.zone.record.update", "dns.zone.record.delete", "peer.job.create", - "user.password.change" + "user.password.change", + "user.invite.link.create", "user.invite.link.accept", "user.invite.link.regenerate", "user.invite.link.delete" ] example: route.add initiator_id: @@ -2642,6 +2808,29 @@ components: required: - user_id - email + InstanceVersionInfo: + type: object + description: Version information for NetBird components + properties: + management_current_version: + description: The current running version of the management server + type: string + example: "0.35.0" + dashboard_available_version: + description: The latest available version of the dashboard (from GitHub releases) + type: string + example: "2.10.0" + management_available_version: + description: The latest available version of the management server (from GitHub releases) + type: string + example: "0.35.0" + management_update_available: + description: Indicates if a newer management version is available + type: boolean + example: true + required: + - management_current_version + - management_update_available responses: not_found: description: Resource not found @@ -2694,6 +2883,27 @@ paths: $ref: '#/components/schemas/InstanceStatus' '500': "$ref": "#/components/responses/internal_error" + /api/instance/version: + get: + summary: Get Version Info + description: Returns version information for NetBird components including the current management server version and latest available versions from GitHub. + tags: [ Instance ] + security: + - BearerAuth: [] + - TokenAuth: [] + responses: + '200': + description: Version information + content: + application/json: + schema: + $ref: '#/components/schemas/InstanceVersionInfo' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/setup: post: summary: Setup Instance @@ -3312,6 +3522,210 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/invites: + get: + summary: List user invites + description: Lists all pending invites for the account. Only available when embedded IdP is enabled. + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: List of invites + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/UserInvite' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a user invite + description: Creates an invite link for a new user. Only available when embedded IdP is enabled. The user is not created until they accept the invite. + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: User invite information + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteCreateRequest' + responses: + '200': + description: Invite created successfully + content: + application/json: + schema: + $ref: '#/components/schemas/UserInvite' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '409': + description: User or invite already exists + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '422': + "$ref": "#/components/responses/validation_failed" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{inviteId}: + delete: + summary: Delete a user invite + description: Deletes a pending invite. Only available when embedded IdP is enabled. + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: inviteId + required: true + schema: + type: string + description: The ID of the invite to delete + responses: + '200': + description: Invite deleted successfully + content: { } + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + description: Invite not found + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{inviteId}/regenerate: + post: + summary: Regenerate a user invite + description: Regenerates an invite link for an existing invite. Invalidates the previous token and creates a new one. + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: inviteId + required: true + schema: + type: string + description: The ID of the invite to regenerate + requestBody: + description: Regenerate options + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteRegenerateRequest' + responses: + '200': + description: Invite regenerated successfully + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteRegenerateResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + description: Invite not found + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled + content: { } + '422': + "$ref": "#/components/responses/validation_failed" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{token}: + get: + summary: Get invite information + description: Retrieves public information about an invite. This endpoint is unauthenticated and protected by the token itself. + tags: [ Users ] + security: [] + parameters: + - in: path + name: token + required: true + schema: + type: string + description: The invite token + responses: + '200': + description: Invite information + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteInfo' + '400': + "$ref": "#/components/responses/bad_request" + '404': + description: Invite not found or invalid token + content: { } + '500': + "$ref": "#/components/responses/internal_error" + /api/users/invites/{token}/accept: + post: + summary: Accept an invite + description: Accepts an invite and creates the user with the provided password. This endpoint is unauthenticated and protected by the token itself. + tags: [ Users ] + security: [] + parameters: + - in: path + name: token + required: true + schema: + type: string + description: The invite token + requestBody: + description: Password to set for the new user + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteAcceptRequest' + responses: + '200': + description: Invite accepted successfully + content: + application/json: + schema: + $ref: '#/components/schemas/UserInviteAcceptResponse' + '400': + "$ref": "#/components/responses/bad_request" + '404': + description: Invite not found or invalid token + content: { } + '412': + description: Precondition failed - embedded IdP is not enabled or invite expired + content: { } + '422': + "$ref": "#/components/responses/validation_failed" + '500': + "$ref": "#/components/responses/internal_error" /api/peers: get: summary: List all Peers diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 848023689..e8c044b32 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -123,6 +123,10 @@ const ( EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add" EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete" EventActivityCodeUserInvite EventActivityCode = "user.invite" + EventActivityCodeUserInviteLinkAccept EventActivityCode = "user.invite.link.accept" + EventActivityCodeUserInviteLinkCreate EventActivityCode = "user.invite.link.create" + EventActivityCodeUserInviteLinkDelete EventActivityCode = "user.invite.link.delete" + EventActivityCodeUserInviteLinkRegenerate EventActivityCode = "user.invite.link.regenerate" EventActivityCodeUserJoin EventActivityCode = "user.join" EventActivityCodeUserPasswordChange EventActivityCode = "user.password.change" EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" @@ -870,6 +874,21 @@ type InstanceStatus struct { SetupRequired bool `json:"setup_required"` } +// InstanceVersionInfo Version information for NetBird components +type InstanceVersionInfo struct { + // DashboardAvailableVersion The latest available version of the dashboard (from GitHub releases) + DashboardAvailableVersion *string `json:"dashboard_available_version,omitempty"` + + // ManagementAvailableVersion The latest available version of the management server (from GitHub releases) + ManagementAvailableVersion *string `json:"management_available_version,omitempty"` + + // ManagementCurrentVersion The current running version of the management server + ManagementCurrentVersion string `json:"management_current_version"` + + // ManagementUpdateAvailable Indicates if a newer management version is available + ManagementUpdateAvailable bool `json:"management_update_available"` +} + // JobRequest defines model for JobRequest. type JobRequest struct { Workload WorkloadRequest `json:"workload"` @@ -2166,6 +2185,99 @@ type UserCreateRequest struct { Role string `json:"role"` } +// UserInvite A user invite +type UserInvite struct { + // AutoGroups Group IDs to auto-assign to peers registered by this user + AutoGroups []string `json:"auto_groups"` + + // CreatedAt Invite creation time + CreatedAt time.Time `json:"created_at"` + + // Email User's email address + Email string `json:"email"` + + // Expired Whether the invite has expired + Expired bool `json:"expired"` + + // ExpiresAt Invite expiration time + ExpiresAt time.Time `json:"expires_at"` + + // Id Invite ID + Id string `json:"id"` + + // InviteToken The invite link to be shared with the user. Only returned when the invite is created or regenerated. + InviteToken *string `json:"invite_token,omitempty"` + + // Name User's full name + Name string `json:"name"` + + // Role User's NetBird account role + Role string `json:"role"` +} + +// UserInviteAcceptRequest Request to accept an invite and set password +type UserInviteAcceptRequest struct { + // Password The password the user wants to set. Must be at least 8 characters long and contain at least one uppercase letter, one digit, and one special character (any character that is not a letter or digit, including spaces). + Password string `json:"password"` +} + +// UserInviteAcceptResponse Response after accepting an invite +type UserInviteAcceptResponse struct { + // Success Whether the invite was accepted successfully + Success bool `json:"success"` +} + +// UserInviteCreateRequest Request to create a user invite link +type UserInviteCreateRequest struct { + // AutoGroups Group IDs to auto-assign to peers registered by this user + AutoGroups []string `json:"auto_groups"` + + // Email User's email address + Email string `json:"email"` + + // ExpiresIn Invite expiration time in seconds (default 72 hours) + ExpiresIn *int `json:"expires_in,omitempty"` + + // Name User's full name + Name string `json:"name"` + + // Role User's NetBird account role + Role string `json:"role"` +} + +// UserInviteInfo Public information about an invite +type UserInviteInfo struct { + // Email User's email address + Email string `json:"email"` + + // ExpiresAt Invite expiration time + ExpiresAt time.Time `json:"expires_at"` + + // InvitedBy Name of the user who sent the invite + InvitedBy string `json:"invited_by"` + + // Name User's full name + Name string `json:"name"` + + // Valid Whether the invite is still valid (not expired) + Valid bool `json:"valid"` +} + +// UserInviteRegenerateRequest Request to regenerate an invite link +type UserInviteRegenerateRequest struct { + // ExpiresIn Invite expiration time in seconds (default 72 hours) + ExpiresIn *int `json:"expires_in,omitempty"` +} + +// UserInviteRegenerateResponse Response after regenerating an invite +type UserInviteRegenerateResponse struct { + // InviteExpiresAt New invite expiration time + InviteExpiresAt time.Time `json:"invite_expires_at"` + + // InviteToken The new invite token + InviteToken string `json:"invite_token"` +} + // UserPermissions defines model for UserPermissions. type UserPermissions struct { // IsRestricted Indicates whether this User's Peers view is restricted @@ -2418,6 +2530,15 @@ type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest // PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType. type PostApiUsersJSONRequestBody = UserCreateRequest +// PostApiUsersInvitesJSONRequestBody defines body for PostApiUsersInvites for application/json ContentType. +type PostApiUsersInvitesJSONRequestBody = UserInviteCreateRequest + +// PostApiUsersInvitesInviteIdRegenerateJSONRequestBody defines body for PostApiUsersInvitesInviteIdRegenerate for application/json ContentType. +type PostApiUsersInvitesInviteIdRegenerateJSONRequestBody = UserInviteRegenerateRequest + +// PostApiUsersInvitesTokenAcceptJSONRequestBody defines body for PostApiUsersInvitesTokenAccept for application/json ContentType. +type PostApiUsersInvitesTokenAcceptJSONRequestBody = UserInviteAcceptRequest + // PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType. type PutApiUsersUserIdJSONRequestBody = UserRequest