Add user invite link feature for embedded IdP (#5157)

This commit is contained in:
Misha Bragin
2026-01-27 09:42:20 +01:00
committed by GitHub
parent 44ab454a13
commit 7d791620a6
21 changed files with 4832 additions and 2 deletions

View File

@@ -68,6 +68,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if err := bypass.AddBypassPath("/api/setup"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
// Public invite endpoints (tokens start with nbi_)
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -132,6 +139,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
@@ -145,6 +154,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {

View File

@@ -28,6 +28,15 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
// This endpoint is unauthenticated.
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
@@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
Email: userData.Email,
})
}
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
return
}
resp := api.InstanceVersionInfo{
ManagementCurrentVersion: versionInfo.CurrentVersion,
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
}
if versionInfo.DashboardVersion != "" {
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
}
if versionInfo.ManagementVersion != "" {
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
}
util.WriteJSONObject(r.Context(), w, resp)
}

View File

@@ -25,6 +25,7 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
@@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
}
return &nbinstance.VersionInfo{
CurrentVersion: "0.34.0",
DashboardVersion: "2.0.0",
ManagementVersion: "0.35.0",
ManagementUpdateAvailable: true,
}, nil
}
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
@@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceVersionInfo
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
assert.NotNil(t, response.DashboardAvailableVersion)
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
assert.NotNil(t, response.ManagementAvailableVersion)
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
assert.True(t, response.ManagementUpdateAvailable)
}
func TestGetVersionInfo_Error(t *testing.T) {
manager := &mockInstanceManager{
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
return nil, errors.New("failed to fetch versions")
},
}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}

View File

@@ -0,0 +1,263 @@
package users
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
RequestsPerMinute: 10, // 10 attempts per minute per IP
Burst: 5, // Allow burst of 5 requests
CleanupInterval: 10 * time.Minute,
LimiterTTL: 30 * time.Minute,
})
// toUserInviteResponse converts a UserInvite to an API response.
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
autoGroups := invite.UserInfo.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
var inviteLink *string
if invite.InviteToken != "" {
inviteLink = &invite.InviteToken
}
return api.UserInvite{
Id: invite.UserInfo.ID,
Email: invite.UserInfo.Email,
Name: invite.UserInfo.Name,
Role: invite.UserInfo.Role,
AutoGroups: autoGroups,
ExpiresAt: invite.InviteExpiresAt.UTC(),
CreatedAt: invite.InviteCreatedAt.UTC(),
Expired: time.Now().After(invite.InviteExpiresAt),
InviteToken: inviteLink,
}
}
// invitesHandler handles user invite operations
type invitesHandler struct {
accountManager account.Manager
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Create a subrouter for public invite endpoints with rate limiting middleware
publicRouter := router.PathPrefix("/users/invites").Subrouter()
publicRouter.Use(publicInviteRateLimiter.Middleware)
// Public endpoints (no auth required, protected by token and rate limited)
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := make([]api.UserInvite, 0, len(invites))
for _, invite := range invites {
resp = append(resp, toUserInviteResponse(invite))
}
util.WriteJSONObject(r.Context(), w, resp)
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
invite := &types.UserInfo{
Email: req.Email,
Name: req.Name,
Role: req.Role,
AutoGroups: req.AutoGroups,
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
result.InviteCreatedAt = time.Now().UTC()
resp := toUserInviteResponse(result)
util.WriteJSONObject(r.Context(), w, &resp)
}
// getInviteInfo handles GET /api/users/invites/{token}
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := info.ExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
Email: info.Email,
Name: info.Name,
ExpiresAt: expiresAt,
Valid: info.Valid,
InvitedBy: info.InvitedBy,
})
}
// acceptInvite handles POST /api/users/invites/{token}/accept
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
var req api.UserInviteAcceptRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
var req api.UserInviteRegenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Allow empty body (io.EOF) - expiresIn is optional
if !errors.Is(err, io.EOF) {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := result.InviteExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
InviteToken: result.InviteToken,
InviteExpiresAt: expiresAt,
})
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,642 @@
package users
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testInviteID = "test-invite-id"
testInviteToken = "nbi_testtoken123456789012345678"
testEmail = "invite@example.com"
testName = "Test User"
)
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
return &invitesHandler{
accountManager: am,
}
}
func TestListInvites(t *testing.T) {
now := time.Now().UTC()
testInvites := []*types.UserInvite{
{
UserInfo: &types.UserInfo{
ID: "invite-1",
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
},
InviteExpiresAt: now.Add(24 * time.Hour),
InviteCreatedAt: now,
},
{
UserInfo: &types.UserInfo{
ID: "invite-2",
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: nil,
},
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
InviteCreatedAt: now.Add(-48 * time.Hour),
},
}
tt := []struct {
name string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
expectedCount int
}{
{
name: "successful list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return testInvites, nil
},
expectedCount: 2,
},
{
name: "empty list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return []*types.UserInvite{}, nil
},
expectedCount: 0,
},
{
name: "permission denied",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
expectedCount: 0,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
ListUserInvitesFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp []api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Len(t, resp, tc.expectedCount)
}
})
}
}
func TestCreateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful create",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful create with custom expiration",
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 3600, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: []string{},
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "user already exists",
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
},
},
{
name: "invite already exists",
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
},
},
{
name: "permission denied",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "invalid JSON",
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
CreateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testInviteID, resp.Id)
assert.NotNil(t, resp.InviteToken)
assert.NotEmpty(t, *resp.InviteToken)
}
})
}
}
func TestGetInviteInfo(t *testing.T) {
now := time.Now().UTC()
tt := []struct {
name string
token string
expectedStatus int
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
}{
{
name: "successful get valid invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(24 * time.Hour),
Valid: true,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "successful get expired invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(-24 * time.Hour),
Valid: false,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invalid token format",
token: "invalid",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
},
},
{
name: "missing token",
token: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
GetUserInviteInfoFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.getInviteInfo(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteInfo
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testEmail, resp.Email)
assert.Equal(t, testName, resp.Name)
}
})
}
}
func TestAcceptInvite(t *testing.T) {
tt := []struct {
name string
token string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, token, password string) error
}{
{
name: "successful accept",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token, password string) error {
return nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invite expired",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "invite has expired")
},
},
{
name: "embedded IDP not enabled",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing token",
token: "",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON",
token: testInviteToken,
requestBody: `{invalid}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
{
name: "password too short",
token: testInviteToken,
requestBody: `{"password":"Short1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
},
},
{
name: "password missing digit",
token: testInviteToken,
requestBody: `{"password":"NoDigitPass!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
},
},
{
name: "password missing uppercase",
token: testInviteToken,
requestBody: `{"password":"nouppercase1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
},
},
{
name: "password missing special character",
token: testInviteToken,
requestBody: `{"password":"NoSpecial123"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
AcceptUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.acceptInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteAcceptResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.True(t, resp.Success)
}
})
}
}
func TestRegenerateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
inviteID string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful regenerate with empty body",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 0, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful regenerate with custom expiration",
inviteID: testInviteID,
requestBody: `{"expires_in":7200}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 7200, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
requestBody: "",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "missing invite ID",
inviteID: "",
requestBody: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON should return error",
inviteID: testInviteID,
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
RegenerateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
var body io.Reader
if tc.requestBody != "" {
body = bytes.NewBufferString(tc.requestBody)
}
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteRegenerateResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.InviteToken)
}
})
}
}
func TestDeleteInvite(t *testing.T) {
tt := []struct {
name string
inviteID string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}{
{
name: "successful delete",
inviteID: testInviteID,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
inviteID: testInviteID,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing invite ID",
inviteID: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
DeleteUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})
}
}

View File

@@ -2,10 +2,14 @@ package middleware
import (
"context"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// RateLimiterConfig holds configuration for the API rate limiter
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
defer rl.mu.Unlock()
delete(rl.limiters, key)
}
// Middleware returns an HTTP middleware that rate limits requests by client IP.
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
return
}
next.ServeHTTP(w, r)
})
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

View File

@@ -0,0 +1,158 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAPIRateLimiter_Allow(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// First two requests should be allowed (burst)
assert.True(t, rl.Allow("test-key"))
assert.True(t, rl.Allow("test-key"))
// Third request should be denied (exceeded burst)
assert.False(t, rl.Allow("test-key"))
// Different key should be allowed
assert.True(t, rl.Allow("different-key"))
}
func TestAPIRateLimiter_Middleware(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Wrap with rate limiter middleware
handler := rl.Middleware(nextHandler)
// First two requests should pass (burst)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
}
// Third request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
}
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := rl.Middleware(nextHandler)
// Request from first IP
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
assert.Equal(t, http.StatusOK, rr1.Code)
// Second request from first IP should be rate limited
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
// Request from different IP should be allowed
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "192.168.1.2:12345"
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
assert.Equal(t, http.StatusOK, rr3.Code)
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{
name: "remote addr with port",
remoteAddr: "192.168.1.1:12345",
expected: "192.168.1.1",
},
{
name: "remote addr without port",
remoteAddr: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
remoteAddr: "[::1]:12345",
expected: "::1",
},
{
name: "IPv6 without port",
remoteAddr: "::1",
expected: "::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tc.remoteAddr
assert.Equal(t, tc.expected, getClientIP(req))
})
}
}
func TestAPIRateLimiter_Reset(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Use up the burst
assert.True(t, rl.Allow("test-key"))
assert.False(t, rl.Allow("test-key"))
// Reset the limiter
rl.Reset("test-key")
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}