Add user invite link feature for embedded IdP (#5157)

This commit is contained in:
Misha Bragin
2026-01-27 09:42:20 +01:00
committed by GitHub
parent 44ab454a13
commit 7d791620a6
21 changed files with 4832 additions and 2 deletions

View File

@@ -30,6 +30,12 @@ type Manager interface {
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
AcceptUserInvite(ctx context.Context, token, password string) error
RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error)
ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error

View File

@@ -199,6 +199,11 @@ const (
UserPasswordChanged Activity = 103
UserInviteLinkCreated Activity = 104
UserInviteLinkAccepted Activity = 105
UserInviteLinkRegenerated Activity = 106
UserInviteLinkDeleted Activity = 107
AccountDeleted Activity = 99999
)
@@ -327,6 +332,11 @@ var activityMap = map[Activity]Code{
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
UserPasswordChanged: {"User password changed", "user.password.change"},
UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"},
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -68,6 +68,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if err := bypass.AddBypassPath("/api/setup"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
// Public invite endpoints (tokens start with nbi_)
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -132,6 +139,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
@@ -145,6 +154,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {

View File

@@ -28,6 +28,15 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
// This endpoint is unauthenticated.
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
@@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
Email: userData.Email,
})
}
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
return
}
resp := api.InstanceVersionInfo{
ManagementCurrentVersion: versionInfo.CurrentVersion,
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
}
if versionInfo.DashboardVersion != "" {
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
}
if versionInfo.ManagementVersion != "" {
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
}
util.WriteJSONObject(r.Context(), w, resp)
}

View File

@@ -25,6 +25,7 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
@@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
}
return &nbinstance.VersionInfo{
CurrentVersion: "0.34.0",
DashboardVersion: "2.0.0",
ManagementVersion: "0.35.0",
ManagementUpdateAvailable: true,
}, nil
}
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
@@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceVersionInfo
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
assert.NotNil(t, response.DashboardAvailableVersion)
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
assert.NotNil(t, response.ManagementAvailableVersion)
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
assert.True(t, response.ManagementUpdateAvailable)
}
func TestGetVersionInfo_Error(t *testing.T) {
manager := &mockInstanceManager{
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
return nil, errors.New("failed to fetch versions")
},
}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}

View File

@@ -0,0 +1,263 @@
package users
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
RequestsPerMinute: 10, // 10 attempts per minute per IP
Burst: 5, // Allow burst of 5 requests
CleanupInterval: 10 * time.Minute,
LimiterTTL: 30 * time.Minute,
})
// toUserInviteResponse converts a UserInvite to an API response.
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
autoGroups := invite.UserInfo.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
var inviteLink *string
if invite.InviteToken != "" {
inviteLink = &invite.InviteToken
}
return api.UserInvite{
Id: invite.UserInfo.ID,
Email: invite.UserInfo.Email,
Name: invite.UserInfo.Name,
Role: invite.UserInfo.Role,
AutoGroups: autoGroups,
ExpiresAt: invite.InviteExpiresAt.UTC(),
CreatedAt: invite.InviteCreatedAt.UTC(),
Expired: time.Now().After(invite.InviteExpiresAt),
InviteToken: inviteLink,
}
}
// invitesHandler handles user invite operations
type invitesHandler struct {
accountManager account.Manager
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Create a subrouter for public invite endpoints with rate limiting middleware
publicRouter := router.PathPrefix("/users/invites").Subrouter()
publicRouter.Use(publicInviteRateLimiter.Middleware)
// Public endpoints (no auth required, protected by token and rate limited)
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := make([]api.UserInvite, 0, len(invites))
for _, invite := range invites {
resp = append(resp, toUserInviteResponse(invite))
}
util.WriteJSONObject(r.Context(), w, resp)
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
invite := &types.UserInfo{
Email: req.Email,
Name: req.Name,
Role: req.Role,
AutoGroups: req.AutoGroups,
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
result.InviteCreatedAt = time.Now().UTC()
resp := toUserInviteResponse(result)
util.WriteJSONObject(r.Context(), w, &resp)
}
// getInviteInfo handles GET /api/users/invites/{token}
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := info.ExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
Email: info.Email,
Name: info.Name,
ExpiresAt: expiresAt,
Valid: info.Valid,
InvitedBy: info.InvitedBy,
})
}
// acceptInvite handles POST /api/users/invites/{token}/accept
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
var req api.UserInviteAcceptRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
var req api.UserInviteRegenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Allow empty body (io.EOF) - expiresIn is optional
if !errors.Is(err, io.EOF) {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := result.InviteExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
InviteToken: result.InviteToken,
InviteExpiresAt: expiresAt,
})
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,642 @@
package users
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testInviteID = "test-invite-id"
testInviteToken = "nbi_testtoken123456789012345678"
testEmail = "invite@example.com"
testName = "Test User"
)
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
return &invitesHandler{
accountManager: am,
}
}
func TestListInvites(t *testing.T) {
now := time.Now().UTC()
testInvites := []*types.UserInvite{
{
UserInfo: &types.UserInfo{
ID: "invite-1",
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
},
InviteExpiresAt: now.Add(24 * time.Hour),
InviteCreatedAt: now,
},
{
UserInfo: &types.UserInfo{
ID: "invite-2",
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: nil,
},
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
InviteCreatedAt: now.Add(-48 * time.Hour),
},
}
tt := []struct {
name string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
expectedCount int
}{
{
name: "successful list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return testInvites, nil
},
expectedCount: 2,
},
{
name: "empty list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return []*types.UserInvite{}, nil
},
expectedCount: 0,
},
{
name: "permission denied",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
expectedCount: 0,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
ListUserInvitesFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp []api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Len(t, resp, tc.expectedCount)
}
})
}
}
func TestCreateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful create",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful create with custom expiration",
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 3600, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: []string{},
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "user already exists",
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
},
},
{
name: "invite already exists",
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
},
},
{
name: "permission denied",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "invalid JSON",
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
CreateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testInviteID, resp.Id)
assert.NotNil(t, resp.InviteToken)
assert.NotEmpty(t, *resp.InviteToken)
}
})
}
}
func TestGetInviteInfo(t *testing.T) {
now := time.Now().UTC()
tt := []struct {
name string
token string
expectedStatus int
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
}{
{
name: "successful get valid invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(24 * time.Hour),
Valid: true,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "successful get expired invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(-24 * time.Hour),
Valid: false,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invalid token format",
token: "invalid",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
},
},
{
name: "missing token",
token: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
GetUserInviteInfoFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.getInviteInfo(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteInfo
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testEmail, resp.Email)
assert.Equal(t, testName, resp.Name)
}
})
}
}
func TestAcceptInvite(t *testing.T) {
tt := []struct {
name string
token string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, token, password string) error
}{
{
name: "successful accept",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token, password string) error {
return nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invite expired",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "invite has expired")
},
},
{
name: "embedded IDP not enabled",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing token",
token: "",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON",
token: testInviteToken,
requestBody: `{invalid}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
{
name: "password too short",
token: testInviteToken,
requestBody: `{"password":"Short1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
},
},
{
name: "password missing digit",
token: testInviteToken,
requestBody: `{"password":"NoDigitPass!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
},
},
{
name: "password missing uppercase",
token: testInviteToken,
requestBody: `{"password":"nouppercase1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
},
},
{
name: "password missing special character",
token: testInviteToken,
requestBody: `{"password":"NoSpecial123"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
AcceptUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.acceptInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteAcceptResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.True(t, resp.Success)
}
})
}
}
func TestRegenerateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
inviteID string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful regenerate with empty body",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 0, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful regenerate with custom expiration",
inviteID: testInviteID,
requestBody: `{"expires_in":7200}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 7200, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
requestBody: "",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "missing invite ID",
inviteID: "",
requestBody: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON should return error",
inviteID: testInviteID,
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
RegenerateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
var body io.Reader
if tc.requestBody != "" {
body = bytes.NewBufferString(tc.requestBody)
}
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteRegenerateResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.InviteToken)
}
})
}
}
func TestDeleteInvite(t *testing.T) {
tt := []struct {
name string
inviteID string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}{
{
name: "successful delete",
inviteID: testInviteID,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
inviteID: testInviteID,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing invite ID",
inviteID: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
DeleteUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})
}
}

View File

@@ -2,10 +2,14 @@ package middleware
import (
"context"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// RateLimiterConfig holds configuration for the API rate limiter
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
defer rl.mu.Unlock()
delete(rl.limiters, key)
}
// Middleware returns an HTTP middleware that rate limits requests by client IP.
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
return
}
next.ServeHTTP(w, r)
})
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

View File

@@ -0,0 +1,158 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAPIRateLimiter_Allow(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// First two requests should be allowed (burst)
assert.True(t, rl.Allow("test-key"))
assert.True(t, rl.Allow("test-key"))
// Third request should be denied (exceeded burst)
assert.False(t, rl.Allow("test-key"))
// Different key should be allowed
assert.True(t, rl.Allow("different-key"))
}
func TestAPIRateLimiter_Middleware(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Wrap with rate limiter middleware
handler := rl.Middleware(nextHandler)
// First two requests should pass (burst)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
}
// Third request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
}
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := rl.Middleware(nextHandler)
// Request from first IP
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
assert.Equal(t, http.StatusOK, rr1.Code)
// Second request from first IP should be rate limited
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
// Request from different IP should be allowed
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "192.168.1.2:12345"
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
assert.Equal(t, http.StatusOK, rr3.Code)
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{
name: "remote addr with port",
remoteAddr: "192.168.1.1:12345",
expected: "192.168.1.1",
},
{
name: "remote addr without port",
remoteAddr: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
remoteAddr: "[::1]:12345",
expected: "::1",
},
{
name: "IPv6 without port",
remoteAddr: "::1",
expected: "::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tc.remoteAddr
assert.Equal(t, tc.expected, getClientIP(req))
})
}
}
func TestAPIRateLimiter_Reset(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Use up the burst
assert.True(t, rl.Allow("test-key"))
assert.False(t, rl.Allow("test-key"))
// Reset the limiter
rl.Reset("test-key")
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}

View File

@@ -2,18 +2,54 @@ package instance
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/mail"
"strings"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const (
// Version endpoints
managementVersionURL = "https://pkgs.netbird.io/releases/latest/version"
dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest"
// Cache TTL for version information
versionCacheTTL = 60 * time.Minute
// HTTP client timeout
httpTimeout = 5 * time.Second
)
// VersionInfo contains version information for NetBird components
type VersionInfo struct {
// CurrentVersion is the running management server version
CurrentVersion string
// DashboardVersion is the latest available dashboard version from GitHub
DashboardVersion string
// ManagementVersion is the latest available management version from GitHub
ManagementVersion string
// ManagementUpdateAvailable indicates if a newer management version is available
ManagementUpdateAvailable bool
}
// githubRelease represents a GitHub release response
type githubRelease struct {
TagName string `json:"tag_name"`
}
// Manager handles instance-level operations like initial setup.
type Manager interface {
// IsSetupRequired checks if instance setup is required.
@@ -23,6 +59,9 @@ type Manager interface {
// CreateOwnerUser creates the initial owner user in the embedded IDP.
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
// GetVersionInfo returns version information for NetBird components.
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
// DefaultManager is the default implementation of Manager.
@@ -32,6 +71,12 @@ type DefaultManager struct {
setupRequired bool
setupMu sync.RWMutex
// Version caching
httpClient *http.Client
versionMu sync.RWMutex
cachedVersions *VersionInfo
lastVersionFetch time.Time
}
// NewManager creates a new instance manager.
@@ -43,6 +88,9 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
httpClient: &http.Client{
Timeout: httpTimeout,
},
}
if embeddedIdp != nil {
@@ -134,3 +182,130 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
}
return nil
}
// GetVersionInfo returns version information for NetBird components.
func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) {
m.versionMu.RLock()
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
cached := *m.cachedVersions
m.versionMu.RUnlock()
return &cached, nil
}
m.versionMu.RUnlock()
return m.fetchVersionInfo(ctx)
}
func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
m.versionMu.Lock()
// Double-check after acquiring write lock
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
cached := *m.cachedVersions
m.versionMu.Unlock()
return &cached, nil
}
m.versionMu.Unlock()
info := &VersionInfo{
CurrentVersion: version.NetbirdVersion(),
}
// Fetch management version from pkgs.netbird.io (plain text)
mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL)
if err != nil {
log.WithContext(ctx).Warnf("failed to fetch management version: %v", err)
} else {
info.ManagementVersion = mgmtVersion
info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion)
}
// Fetch dashboard version from GitHub
dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL)
if err != nil {
log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err)
} else {
info.DashboardVersion = dashVersion
}
// Update cache
m.versionMu.Lock()
m.cachedVersions = info
m.lastVersionFetch = time.Now()
m.versionMu.Unlock()
return info, nil
}
// isNewerVersion returns true if latestVersion is greater than currentVersion
func isNewerVersion(currentVersion, latestVersion string) bool {
current, err := goversion.NewVersion(currentVersion)
if err != nil {
return false
}
latest, err := goversion.NewVersion(latestVersion)
if err != nil {
return false
}
return latest.GreaterThan(current)
}
func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
resp, err := m.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 100))
if err != nil {
return "", fmt.Errorf("read response: %w", err)
}
return strings.TrimSpace(string(body)), nil
}
func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
resp, err := m.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var release githubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return "", fmt.Errorf("decode response: %w", err)
}
// Remove 'v' prefix if present
tag := release.TagName
if len(tag) > 0 && tag[0] == 'v' {
tag = tag[1:]
}
return tag, nil
}

View File

@@ -0,0 +1,285 @@
package instance
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockRoundTripper implements http.RoundTripper for testing
type mockRoundTripper struct {
callCount atomic.Int32
managementVersion string
dashboardVersion string
}
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
m.callCount.Add(1)
var body string
if strings.Contains(req.URL.String(), "pkgs.netbird.io") {
// Plain text response for management version
body = m.managementVersion
} else if strings.Contains(req.URL.String(), "github.com") {
// JSON response for dashboard version
jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion})
body = string(jsonResp)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: make(http.Header),
}, nil
}
func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) {
mockTransport := &mockRoundTripper{
managementVersion: "0.65.0",
dashboardVersion: "2.10.0",
}
m := &DefaultManager{
httpClient: &http.Client{Transport: mockTransport},
}
ctx := context.Background()
info, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
// CurrentVersion should always be set
assert.NotEmpty(t, info.CurrentVersion)
assert.Equal(t, "0.65.0", info.ManagementVersion)
assert.Equal(t, "2.10.0", info.DashboardVersion)
assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard
}
func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) {
mockTransport := &mockRoundTripper{
managementVersion: "0.65.0",
dashboardVersion: "2.10.0",
}
m := &DefaultManager{
httpClient: &http.Client{Transport: mockTransport},
}
ctx := context.Background()
// First call
info1, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
assert.NotEmpty(t, info1.CurrentVersion)
assert.Equal(t, "0.65.0", info1.ManagementVersion)
initialCallCount := mockTransport.callCount.Load()
// Second call should use cache (no additional HTTP calls)
info2, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion)
assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion)
assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion)
// Verify no additional HTTP calls were made (cache was used)
assert.Equal(t, initialCallCount, mockTransport.callCount.Load())
}
func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) {
tests := []struct {
name string
tagName string
expected string
shouldError bool
}{
{
name: "tag with v prefix",
tagName: "v1.2.3",
expected: "1.2.3",
},
{
name: "tag without v prefix",
tagName: "1.2.3",
expected: "1.2.3",
},
{
name: "tag with prerelease",
tagName: "v2.0.0-beta.1",
expected: "2.0.0-beta.1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName})
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
version, err := m.fetchGitHubRelease(context.Background(), server.URL)
if tc.shouldError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expected, version)
}
})
}
}
func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
}{
{
name: "not found",
statusCode: http.StatusNotFound,
body: `{"message": "Not Found"}`,
},
{
name: "rate limited",
statusCode: http.StatusForbidden,
body: `{"message": "API rate limit exceeded"}`,
},
{
name: "server error",
statusCode: http.StatusInternalServerError,
body: `{"message": "Internal Server Error"}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.statusCode)
_, _ = w.Write([]byte(tc.body))
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
assert.Error(t, err)
})
}
}
func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{invalid json}`))
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
assert.Error(t, err)
}
func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"})
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := m.fetchGitHubRelease(ctx, server.URL)
assert.Error(t, err)
}
func TestIsNewerVersion(t *testing.T) {
tests := []struct {
name string
currentVersion string
latestVersion string
expected bool
}{
{
name: "latest is newer - minor version",
currentVersion: "0.64.1",
latestVersion: "0.65.0",
expected: true,
},
{
name: "latest is newer - patch version",
currentVersion: "0.64.1",
latestVersion: "0.64.2",
expected: true,
},
{
name: "latest is newer - major version",
currentVersion: "0.64.1",
latestVersion: "1.0.0",
expected: true,
},
{
name: "versions are equal",
currentVersion: "0.64.1",
latestVersion: "0.64.1",
expected: false,
},
{
name: "current is newer - minor version",
currentVersion: "0.65.0",
latestVersion: "0.64.1",
expected: false,
},
{
name: "current is newer - patch version",
currentVersion: "0.64.2",
latestVersion: "0.64.1",
expected: false,
},
{
name: "development version",
currentVersion: "development",
latestVersion: "0.65.0",
expected: false,
},
{
name: "invalid latest version",
currentVersion: "0.64.1",
latestVersion: "invalid",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := isNewerVersion(tc.currentVersion, tc.latestVersion)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -139,6 +139,12 @@ type MockAccountManager struct {
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
AcceptUserInviteFunc func(ctx context.Context, token, password string) error
RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
@@ -713,6 +719,48 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
if am.CreateUserInviteFunc != nil {
return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented")
}
func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
if am.AcceptUserInviteFunc != nil {
return am.AcceptUserInviteFunc(ctx, token, password)
}
return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented")
}
func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
if am.RegenerateUserInviteFunc != nil {
return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented")
}
func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
if am.GetUserInviteInfoFunc != nil {
return am.GetUserInviteInfoFunc(ctx, token)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented")
}
func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
if am.ListUserInvitesFunc != nil {
return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented")
}
func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
if am.DeleteUserInviteFunc != nil {
return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented")
}
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if am.GetAccountIDFromUserAuthFunc != nil {
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)

View File

@@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -815,6 +815,130 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return &user, nil
}
// SaveUserInvite saves a user invite to the database
func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error {
inviteCopy := invite.Copy()
if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt invite: %w", err)
}
result := s.db.Save(inviteCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user invite to store")
}
return nil
}
// GetUserInviteByID retrieves a user invite by its ID and account ID
func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invite types.UserInviteRecord
result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user invite not found")
}
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
}
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
return &invite, nil
}
// GetUserInviteByHashedToken retrieves a user invite by its hashed token
func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invite types.UserInviteRecord
result := tx.Take(&invite, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user invite not found")
}
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
}
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
return &invite, nil
}
// GetUserInviteByEmail retrieves a user invite by account ID and email.
// Since email is encrypted with random IVs, we fetch all invites for the account
// and compare emails in memory after decryption.
func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invites []*types.UserInviteRecord
result := tx.Find(&invites, "account_id = ?", accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
}
for _, invite := range invites {
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
if strings.EqualFold(invite.Email, email) {
return invite, nil
}
}
return nil, status.Errorf(status.NotFound, "user invite not found for email")
}
// GetAccountUserInvites retrieves all user invites for an account
func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invites []*types.UserInviteRecord
result := tx.Find(&invites, "account_id = ?", accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
}
for _, invite := range invites {
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
}
return invites, nil
}
// DeleteUserInvite deletes a user invite by its ID
func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error {
result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete user invite from store")
}
return nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -0,0 +1,520 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/types"
)
func TestSqlStore_SaveUserInvite(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-1",
AccountID: "account-1",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group-1", "group-2"},
HashedToken: "hashed-token-123",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify the invite was saved
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
assert.Equal(t, invite.Name, retrieved.Name)
assert.Equal(t, invite.Role, retrieved.Role)
assert.Equal(t, invite.AutoGroups, retrieved.AutoGroups)
assert.Equal(t, invite.CreatedBy, retrieved.CreatedBy)
})
}
func TestSqlStore_SaveUserInvite_Update(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-update",
AccountID: "account-1",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-123",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Update the invite with a new token
invite.HashedToken = "new-hashed-token"
invite.ExpiresAt = time.Now().Add(24 * time.Hour)
err = store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify the update
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, "new-hashed-token", retrieved.HashedToken)
})
}
func TestSqlStore_GetUserInviteByID(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-id",
AccountID: "account-1",
Email: "getbyid@example.com",
Name: "Get By ID User",
Role: "admin",
AutoGroups: []string{},
HashedToken: "hashed-token-get",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by ID - success
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
// Get by ID - wrong account
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID)
assert.Error(t, err)
// Get by ID - not found
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent")
assert.Error(t, err)
})
}
func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-token",
AccountID: "account-1",
Email: "getbytoken@example.com",
Name: "Get By Token User",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "unique-hashed-token-456",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by hashed token - success
retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
// Get by hashed token - not found
_, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token")
assert.Error(t, err)
})
}
func TestSqlStore_GetUserInviteByEmail(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-email",
AccountID: "account-email-test",
Email: "unique-email@example.com",
Name: "Get By Email User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-email",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by email - success
retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
// Get by email - case insensitive
retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM")
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
// Get by email - wrong account
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email)
assert.Error(t, err)
// Get by email - not found
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com")
assert.Error(t, err)
})
}
func TestSqlStore_GetAccountUserInvites(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
accountID := "account-list-invites"
invites := []*types.UserInviteRecord{
{
ID: "invite-list-1",
AccountID: accountID,
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-list-1",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
{
ID: "invite-list-2",
AccountID: accountID,
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: []string{"group-2"},
HashedToken: "hashed-token-list-2",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
{
ID: "invite-list-3",
AccountID: "different-account",
Email: "user3@example.com",
Name: "User Three",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-list-3",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
}
for _, invite := range invites {
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
}
// Get all invites for the account
retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
assert.Len(t, retrieved, 2)
// Verify the invites belong to the correct account
for _, invite := range retrieved {
assert.Equal(t, accountID, invite.AccountID)
}
// Get invites for account with no invites
retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account")
require.NoError(t, err)
assert.Len(t, retrieved, 0)
})
}
func TestSqlStore_DeleteUserInvite(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-delete",
AccountID: "account-delete-test",
Email: "delete@example.com",
Name: "Delete User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-delete",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify invite exists
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Delete the invite
err = store.DeleteUserInvite(ctx, invite.ID)
require.NoError(t, err)
// Verify invite is deleted
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
assert.Error(t, err)
})
}
func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-encrypted",
AccountID: "account-encrypted",
Email: "sensitive-email@example.com",
Name: "Sensitive Name",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-encrypted",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Retrieve and verify decryption works
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, "sensitive-email@example.com", retrieved.Email)
assert.Equal(t, "Sensitive Name", retrieved.Name)
})
}
func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
// Deleting a non-existent invite should not return an error
err := store.DeleteUserInvite(ctx, "non-existent-invite-id")
require.NoError(t, err)
})
}
func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
email := "shared-email@example.com"
// Create invite in first account
invite1 := &types.UserInviteRecord{
ID: "invite-account1",
AccountID: "account-1",
Email: email,
Name: "User Account 1",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-account1",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-1",
}
// Create invite in second account with same email
invite2 := &types.UserInviteRecord{
ID: "invite-account2",
AccountID: "account-2",
Email: email,
Name: "User Account 2",
Role: "admin",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-account2",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-2",
}
err := store.SaveUserInvite(ctx, invite1)
require.NoError(t, err)
err = store.SaveUserInvite(ctx, invite2)
require.NoError(t, err)
// Verify each account gets the correct invite by email
retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email)
require.NoError(t, err)
assert.Equal(t, "invite-account1", retrieved1.ID)
assert.Equal(t, "User Account 1", retrieved1.Name)
retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email)
require.NoError(t, err)
assert.Equal(t, "invite-account2", retrieved2.ID)
assert.Equal(t, "User Account 2", retrieved2.Name)
})
}
func TestSqlStore_UserInvite_LockingStrength(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-locking",
AccountID: "account-locking",
Email: "locking@example.com",
Name: "Locking Test User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-locking",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Test with different locking strengths
lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate}
for _, strength := range lockStrengths {
retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID)
require.NoError(t, err)
assert.Len(t, invites, 1)
}
})
}
func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
// Test with nil AutoGroups
invite := &types.UserInviteRecord{
ID: "invite-nil-autogroups",
AccountID: "account-autogroups",
Email: "nilgroups@example.com",
Name: "Nil Groups User",
Role: "user",
AutoGroups: nil,
HashedToken: "hashed-token-nil",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Should return empty slice or nil, both are acceptable
assert.Empty(t, retrieved.AutoGroups)
})
}
func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
now := time.Now().UTC().Truncate(time.Millisecond)
expiresAt := now.Add(72 * time.Hour)
invite := &types.UserInviteRecord{
ID: "invite-timestamp",
AccountID: "account-timestamp",
Email: "timestamp@example.com",
Name: "Timestamp User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-timestamp",
ExpiresAt: expiresAt,
CreatedAt: now,
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Verify timestamps are preserved (within reasonable precision)
assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second)
assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second)
})
}

View File

@@ -92,6 +92,13 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error
GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error)
GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error)
GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error)
GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error)
DeleteUserInvite(ctx context.Context, inviteID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)

View File

@@ -0,0 +1,201 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"strings"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/util/crypt"
)
const (
// InviteTokenPrefix is the prefix for invite tokens
InviteTokenPrefix = "nbi_"
// InviteTokenSecretLength is the length of the random secret part
InviteTokenSecretLength = 30
// InviteTokenChecksumLength is the length of the encoded checksum
InviteTokenChecksumLength = 6
// InviteTokenLength is the total length of the token (4 + 30 + 6 = 40)
InviteTokenLength = 40
// DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours)
DefaultInviteExpirationSeconds = 259200
// MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour)
MinInviteExpirationSeconds = 3600
)
// UserInviteRecord represents an invitation for a user to set up their account (database model)
type UserInviteRecord struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index;not null"`
Email string `gorm:"index;not null"`
Name string `gorm:"not null"`
Role string `gorm:"not null"`
AutoGroups []string `gorm:"serializer:json"`
HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded)
ExpiresAt time.Time `gorm:"not null"`
CreatedAt time.Time `gorm:"not null"`
CreatedBy string `gorm:"not null"`
}
// TableName returns the table name for GORM
func (UserInviteRecord) TableName() string {
return "user_invites"
}
// GenerateInviteToken creates a new invite token with the format: nbi_<secret><checksum>
// Returns the hashed token (for storage) and the plain token (to give to the user)
func GenerateInviteToken() (hashedToken string, plainToken string, err error) {
secret, err := b.Random(InviteTokenSecretLength)
if err != nil {
return "", "", fmt.Errorf("failed to generate random secret: %w", err)
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
// Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode)
paddedChecksum := encodedChecksum
if len(paddedChecksum) < InviteTokenChecksumLength {
paddedChecksum = strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum
}
plainToken = InviteTokenPrefix + secret + paddedChecksum
hash := sha256.Sum256([]byte(plainToken))
hashedToken = b64.StdEncoding.EncodeToString(hash[:])
return hashedToken, plainToken, nil
}
// HashInviteToken creates a SHA-256 hash of the token (base64 encoded)
func HashInviteToken(token string) string {
hash := sha256.Sum256([]byte(token))
return b64.StdEncoding.EncodeToString(hash[:])
}
// ValidateInviteToken validates the token format and checksum.
// Returns an error if the token is invalid.
func ValidateInviteToken(token string) error {
if len(token) != InviteTokenLength {
return fmt.Errorf("invalid token length")
}
prefix := token[:len(InviteTokenPrefix)]
if prefix != InviteTokenPrefix {
return fmt.Errorf("invalid token prefix")
}
secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return fmt.Errorf("checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return fmt.Errorf("checksum does not match")
}
return nil
}
// IsExpired checks if the invite has expired
func (i *UserInviteRecord) IsExpired() bool {
return time.Now().After(i.ExpiresAt)
}
// UserInvite contains the result of creating or regenerating an invite
type UserInvite struct {
UserInfo *UserInfo
InviteToken string
InviteExpiresAt time.Time
InviteCreatedAt time.Time
}
// UserInviteInfo contains public information about an invite (for unauthenticated endpoint)
type UserInviteInfo struct {
Email string `json:"email"`
Name string `json:"name"`
ExpiresAt time.Time `json:"expires_at"`
Valid bool `json:"valid"`
InvitedBy string `json:"invited_by"`
}
// NewInviteID generates a new invite ID using xid
func NewInviteID() string {
return xid.New().String()
}
// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place.
func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if i.Email != "" {
i.Email, err = enc.Encrypt(i.Email)
if err != nil {
return fmt.Errorf("encrypt email: %w", err)
}
}
if i.Name != "" {
i.Name, err = enc.Encrypt(i.Name)
if err != nil {
return fmt.Errorf("encrypt name: %w", err)
}
}
return nil
}
// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place.
func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if i.Email != "" {
i.Email, err = enc.Decrypt(i.Email)
if err != nil {
return fmt.Errorf("decrypt email: %w", err)
}
}
if i.Name != "" {
i.Name, err = enc.Decrypt(i.Name)
if err != nil {
return fmt.Errorf("decrypt name: %w", err)
}
}
return nil
}
// Copy creates a deep copy of the UserInviteRecord
func (i *UserInviteRecord) Copy() *UserInviteRecord {
autoGroups := make([]string, len(i.AutoGroups))
copy(autoGroups, i.AutoGroups)
return &UserInviteRecord{
ID: i.ID,
AccountID: i.AccountID,
Email: i.Email,
Name: i.Name,
Role: i.Role,
AutoGroups: autoGroups,
HashedToken: i.HashedToken,
ExpiresAt: i.ExpiresAt,
CreatedAt: i.CreatedAt,
CreatedBy: i.CreatedBy,
}
}

View File

@@ -0,0 +1,355 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"hash/crc32"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUserInviteRecord_TableName(t *testing.T) {
invite := UserInviteRecord{}
assert.Equal(t, "user_invites", invite.TableName())
}
func TestGenerateInviteToken_Success(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.NotEmpty(t, hashedToken)
assert.NotEmpty(t, plainToken)
}
func TestGenerateInviteToken_Length(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.Len(t, plainToken, InviteTokenLength)
}
func TestGenerateInviteToken_Prefix(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix))
}
func TestGenerateInviteToken_Hashing(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
expectedHash := sha256.Sum256([]byte(plainToken))
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
assert.Equal(t, expectedHashedToken, hashedToken)
}
func TestGenerateInviteToken_Checksum(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Extract parts
secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:]
// Verify checksum
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
actualChecksum, err := base62.Decode(checksumStr)
require.NoError(t, err)
assert.Equal(t, expectedChecksum, actualChecksum)
}
func TestGenerateInviteToken_Uniqueness(t *testing.T) {
tokens := make(map[string]bool)
for i := 0; i < 100; i++ {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.False(t, tokens[plainToken], "Token should be unique")
tokens[plainToken] = true
}
}
func TestHashInviteToken(t *testing.T) {
token := "nbi_testtoken123456789012345678901234"
hashedToken := HashInviteToken(token)
expectedHash := sha256.Sum256([]byte(token))
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
assert.Equal(t, expectedHashedToken, hashedToken)
}
func TestHashInviteToken_Consistency(t *testing.T) {
token := "nbi_testtoken123456789012345678901234"
hash1 := HashInviteToken(token)
hash2 := HashInviteToken(token)
assert.Equal(t, hash1, hash2)
}
func TestHashInviteToken_DifferentTokens(t *testing.T) {
token1 := "nbi_testtoken123456789012345678901234"
token2 := "nbi_testtoken123456789012345678901235"
hash1 := HashInviteToken(token1)
hash2 := HashInviteToken(token2)
assert.NotEqual(t, hash1, hash2)
}
func TestValidateInviteToken_Success(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
err = ValidateInviteToken(plainToken)
assert.NoError(t, err)
}
func TestValidateInviteToken_InvalidLength(t *testing.T) {
testCases := []struct {
name string
token string
}{
{"empty", ""},
{"too short", "nbi_abc"},
{"too long", "nbi_" + strings.Repeat("a", 50)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateInviteToken(tc.token)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid token length")
})
}
}
func TestValidateInviteToken_InvalidPrefix(t *testing.T) {
// Create a token with wrong prefix but correct length
token := "xyz_" + strings.Repeat("a", 30) + "000000"
err := ValidateInviteToken(token)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid token prefix")
}
func TestValidateInviteToken_InvalidChecksum(t *testing.T) {
// Create a token with correct format but invalid checksum
token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ"
err := ValidateInviteToken(token)
require.Error(t, err)
assert.Contains(t, err.Error(), "checksum")
}
func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Modify one character in the secret part
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
err = ValidateInviteToken(modifiedToken)
require.Error(t, err)
}
func TestUserInviteRecord_IsExpired(t *testing.T) {
t.Run("not expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(time.Hour),
}
assert.False(t, invite.IsExpired())
})
t.Run("expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(-time.Hour),
}
assert.True(t, invite.IsExpired())
})
t.Run("just expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(-time.Second),
}
assert.True(t, invite.IsExpired())
})
}
func TestNewInviteID(t *testing.T) {
id := NewInviteID()
assert.NotEmpty(t, id)
assert.Len(t, id, 20) // xid generates 20 character IDs
}
func TestNewInviteID_Uniqueness(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := NewInviteID()
assert.False(t, ids[id], "ID should be unique")
ids[id] = true
}
}
func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) {
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
t.Run("encrypt and decrypt", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
// Encrypt
err := invite.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
// Verify encrypted values are different from original
assert.NotEqual(t, "test@example.com", invite.Email)
assert.NotEqual(t, "Test User", invite.Name)
// Decrypt
err = invite.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
// Verify decrypted values match original
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
})
t.Run("encrypt empty fields", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "",
Name: "",
Role: "user",
}
err := invite.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", invite.Email)
assert.Equal(t, "", invite.Name)
err = invite.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", invite.Email)
assert.Equal(t, "", invite.Name)
})
t.Run("nil encryptor", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
err := invite.EncryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
err = invite.DecryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
})
}
func TestUserInviteRecord_Copy(t *testing.T) {
now := time.Now()
expiresAt := now.Add(72 * time.Hour)
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group1", "group2"},
HashedToken: "hashed-token",
ExpiresAt: expiresAt,
CreatedAt: now,
CreatedBy: "creator-id",
}
copied := original.Copy()
// Verify all fields are copied
assert.Equal(t, original.ID, copied.ID)
assert.Equal(t, original.AccountID, copied.AccountID)
assert.Equal(t, original.Email, copied.Email)
assert.Equal(t, original.Name, copied.Name)
assert.Equal(t, original.Role, copied.Role)
assert.Equal(t, original.AutoGroups, copied.AutoGroups)
assert.Equal(t, original.HashedToken, copied.HashedToken)
assert.Equal(t, original.ExpiresAt, copied.ExpiresAt)
assert.Equal(t, original.CreatedAt, copied.CreatedAt)
assert.Equal(t, original.CreatedBy, copied.CreatedBy)
// Verify deep copy of AutoGroups (modifying copy doesn't affect original)
copied.AutoGroups[0] = "modified"
assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0])
assert.Equal(t, "group1", original.AutoGroups[0])
}
func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) {
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
AutoGroups: []string{},
}
copied := original.Copy()
assert.NotNil(t, copied.AutoGroups)
assert.Len(t, copied.AutoGroups, 0)
}
func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) {
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
AutoGroups: nil,
}
copied := original.Copy()
assert.NotNil(t, copied.AutoGroups)
assert.Len(t, copied.AutoGroups, 0)
}
func TestInviteTokenConstants(t *testing.T) {
// Verify constants are consistent
expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength
assert.Equal(t, InviteTokenLength, expectedLength)
assert.Equal(t, 4, len(InviteTokenPrefix))
assert.Equal(t, 30, InviteTokenSecretLength)
assert.Equal(t, 6, InviteTokenChecksumLength)
assert.Equal(t, 40, InviteTokenLength)
assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours
assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour
}
func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) {
// Generate multiple tokens and ensure they all validate
for i := 0; i < 50; i++ {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
err = ValidateInviteToken(plainToken)
assert.NoError(t, err, "Generated token should always be valid")
}
}
func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// HashInviteToken should produce the same hash as GenerateInviteToken
rehashedToken := HashInviteToken(plainToken)
assert.Equal(t, hashedToken, rehashedToken)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"strings"
"time"
"unicode"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -1453,3 +1454,368 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init
return nil
}
// CreateUserInvite creates an invite link for a new user in the embedded IdP.
// The user is NOT created until the invite is accepted.
func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if err := validateUserInvite(invite); err != nil {
return nil, err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
// Check if user already exists in NetBird DB
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
for _, user := range existingUsers {
if strings.EqualFold(user.Email, invite.Email) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
}
}
// Check if invite already exists for this email
existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return nil, fmt.Errorf("failed to check existing invites: %w", err)
}
}
if existingInvite != nil {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
}
// Calculate expiration time
if expiresIn <= 0 {
expiresIn = types.DefaultInviteExpirationSeconds
}
if expiresIn < types.MinInviteExpirationSeconds {
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
}
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
// Generate invite token
inviteID := types.NewInviteID()
hashedToken, plainToken, err := types.GenerateInviteToken()
if err != nil {
return nil, fmt.Errorf("failed to generate invite token: %w", err)
}
// Create the invite record (no user created yet)
userInvite := &types.UserInviteRecord{
ID: inviteID,
AccountID: accountID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
HashedToken: hashedToken,
ExpiresAt: expiresAt,
CreatedAt: time.Now().UTC(),
CreatedBy: initiatorUserID,
}
if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email})
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
Issued: types.UserIssuedAPI,
},
InviteToken: plainToken,
InviteExpiresAt: expiresAt,
}, nil
}
// GetUserInviteInfo retrieves invite information from a token (public endpoint).
func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
if err := types.ValidateInviteToken(token); err != nil {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
}
hashedToken := types.HashInviteToken(token)
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken)
if err != nil {
return nil, err
}
// Get the inviter's name
invitedBy := ""
if invite.CreatedBy != "" {
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy)
if err == nil && inviter != nil {
invitedBy = inviter.Name
}
}
return &types.UserInviteInfo{
Email: invite.Email,
Name: invite.Name,
ExpiresAt: invite.ExpiresAt,
Valid: !invite.IsExpired(),
InvitedBy: invitedBy,
}, nil
}
// ListUserInvites returns all invites for an account.
func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
invites := make([]*types.UserInvite, 0, len(records))
for _, record := range records {
invites = append(invites, &types.UserInvite{
UserInfo: &types.UserInfo{
ID: record.ID,
Email: record.Email,
Name: record.Name,
Role: record.Role,
AutoGroups: record.AutoGroups,
},
InviteExpiresAt: record.ExpiresAt,
InviteCreatedAt: record.CreatedAt,
})
}
return invites, nil
}
// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB.
func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
if !IsEmbeddedIdp(am.idpManager) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if password == "" {
return status.Errorf(status.InvalidArgument, "password is required")
}
if err := validatePassword(password); err != nil {
return status.Errorf(status.InvalidArgument, "invalid password: %v", err)
}
if err := types.ValidateInviteToken(token); err != nil {
return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
}
hashedToken := types.HashInviteToken(token)
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken)
if err != nil {
return err
}
if invite.IsExpired() {
return status.Errorf(status.InvalidArgument, "invite has expired")
}
// Create user in Dex with the provided password
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return status.Errorf(status.Internal, "failed to get embedded IdP manager")
}
idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name)
if err != nil {
return fmt.Errorf("failed to create user in IdP: %w", err)
}
// Create user in NetBird DB
newUser := &types.User{
Id: idpUser.ID,
AccountID: invite.AccountID,
Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups,
Issued: types.UserIssuedAPI,
CreatedAt: time.Now().UTC(),
Email: invite.Email,
Name: invite.Name,
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.SaveUser(ctx, newUser); err != nil {
return fmt.Errorf("failed to save user: %w", err)
}
if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil {
return fmt.Errorf("failed to delete invite: %w", err)
}
return nil
})
if err != nil {
// Best-effort rollback: delete the IdP user to avoid orphaned records
if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil {
log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID)
}
return err
}
am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email})
return nil
}
// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one.
func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
// Get existing invite
existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
if err != nil {
return nil, err
}
// Calculate expiration time
if expiresIn <= 0 {
expiresIn = types.DefaultInviteExpirationSeconds
}
if expiresIn < types.MinInviteExpirationSeconds {
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
}
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
// Generate new invite token
hashedToken, plainToken, err := types.GenerateInviteToken()
if err != nil {
return nil, fmt.Errorf("failed to generate invite token: %w", err)
}
// Update existing invite with new token and expiration
existingInvite.HashedToken = hashedToken
existingInvite.ExpiresAt = expiresAt
existingInvite.CreatedBy = initiatorUserID
err = am.Store.SaveUserInvite(ctx, existingInvite)
if err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email})
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: existingInvite.ID,
Email: existingInvite.Email,
Name: existingInvite.Name,
Role: existingInvite.Role,
AutoGroups: existingInvite.AutoGroups,
Status: string(types.UserStatusInvited),
Issued: types.UserIssuedAPI,
},
InviteToken: plainToken,
InviteExpiresAt: expiresAt,
}, nil
}
// DeleteUserInvite deletes an existing invite by ID.
func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
if !IsEmbeddedIdp(am.idpManager) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
if err != nil {
return err
}
if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil {
return err
}
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email})
return nil
}
const minPasswordLength = 8
// validatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}
var hasDigit, hasUpper, hasSpecial bool
for _, c := range password {
switch {
case unicode.IsDigit(c):
hasDigit = true
case unicode.IsUpper(c):
hasUpper = true
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
hasSpecial = true
}
}
var missing []string
if !hasDigit {
missing = append(missing, "one digit")
}
if !hasUpper {
missing = append(missing, "one uppercase letter")
}
if !hasSpecial {
missing = append(missing, "one special character")
}
if len(missing) > 0 {
return errors.New("password must contain at least " + strings.Join(missing, ", "))
}
return nil
}

File diff suppressed because it is too large Load Diff