[management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)

Embed Dex as a built-in IdP to simplify self-hosting setup.
Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management.

more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
This commit is contained in:
Misha Bragin
2026-01-07 08:52:32 -05:00
committed by GitHub
parent 5393ad948f
commit e586c20e36
90 changed files with 7702 additions and 517 deletions

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/gorilla/mux"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
@@ -29,6 +30,8 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/idp"
"github.com/netbirdio/netbird/management/server/http/handlers/instance"
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
@@ -36,6 +39,8 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -51,23 +56,15 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
ctx context.Context,
accountManager account.Manager,
networksManager nbnetworks.Manager,
resourceManager resources.Manager,
routerManager routers.Manager,
groupsManager nbgroups.Manager,
LocationManager geolocation.Geolocation,
authManager auth.Manager,
appMetrics telemetry.AppMetrics,
integratedValidator integrated_validator.IntegratedValidator,
proxyController port_forwarding.Controller,
permissionsManager permissions.Manager,
peersManager nbpeers.Manager,
settingsManager settings.Manager,
networkMapController network_map.Controller,
) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if err := bypass.AddBypassPath("/api/setup"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -122,7 +119,14 @@ func NewAPIHandler(
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, router)
// Check if embedded IdP is enabled
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
@@ -134,6 +138,13 @@ func NewAPIHandler(
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
}
return rootRouter, nil
}

View File

@@ -36,22 +36,24 @@ const (
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager account.Manager
settingsManager settings.Manager
accountManager account.Manager
settingsManager settings.Manager
embeddedIdpEnabled bool
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager)
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
return &handler{
accountManager: accountManager,
settingsManager: settingsManager,
accountManager: accountManager,
settingsManager: settingsManager,
embeddedIdpEnabled: embeddedIdpEnabled,
}
}
@@ -163,7 +165,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -290,7 +292,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
util.WriteJSONObject(r.Context(), w, &resp)
}
@@ -319,7 +321,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
@@ -339,6 +341,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
EmbeddedIdpEnabled: &embeddedIdpEnabled,
}
if settings.NetworkRange.IsValid() {

View File

@@ -33,6 +33,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
AnyTimes()
return &handler{
embeddedIdpEnabled: false,
accountManager: &mock_server.MockAccountManager{
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
@@ -122,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: true,
expectedID: accountID,
@@ -145,6 +147,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -168,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr("latest"),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -191,6 +195,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -214,6 +219,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -237,6 +243,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,

View File

@@ -0,0 +1,196 @@
package idp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// handler handles identity provider HTTP endpoints
type handler struct {
accountManager account.Manager
}
// AddEndpoints registers identity provider endpoints
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
h := newHandler(accountManager)
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
}
func newHandler(accountManager account.Manager) *handler {
return &handler{
accountManager: accountManager,
}
}
// getAllIdentityProviders returns all identity providers for the account
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
response := make([]api.IdentityProvider, 0, len(providers))
for _, p := range providers {
response = append(response, toAPIResponse(p))
}
util.WriteJSONObject(r.Context(), w, response)
}
// getIdentityProvider returns a specific identity provider
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
return
}
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(provider))
}
// createIdentityProvider creates a new identity provider
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
idp := fromAPIRequest(&req)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(created))
}
// updateIdentityProvider updates an existing identity provider
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
return
}
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
idp := fromAPIRequest(&req)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(updated))
}
// deleteIdentityProvider deletes an identity provider
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
return
}
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAPIResponse(idp *types.IdentityProvider) api.IdentityProvider {
resp := api.IdentityProvider{
Type: api.IdentityProviderType(idp.Type),
Name: idp.Name,
Issuer: idp.Issuer,
ClientId: idp.ClientID,
}
if idp.ID != "" {
resp.Id = &idp.ID
}
// Note: ClientSecret is never returned in responses for security
return resp
}
func fromAPIRequest(req *api.IdentityProviderRequest) *types.IdentityProvider {
return &types.IdentityProvider{
Type: types.IdentityProviderType(req.Type),
Name: req.Name,
Issuer: req.Issuer,
ClientID: req.ClientId,
ClientSecret: req.ClientSecret,
}
}

View File

@@ -0,0 +1,438 @@
package idp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
existingIDPID = "existing-idp-id"
newIDPID = "new-idp-id"
)
func initIDPTestData(existingIDP *types.IdentityProvider) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetIdentityProvidersFunc: func(_ context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if existingIDP != nil {
return []*types.IdentityProvider{existingIDP}, nil
}
return []*types.IdentityProvider{}, nil
},
GetIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if existingIDP != nil && idpID == existingIDP.ID {
return existingIDP, nil
}
return nil, status.Errorf(status.NotFound, "identity provider not found")
},
CreateIdentityProviderFunc: func(_ context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if idp.Name == "" {
return nil, status.Errorf(status.InvalidArgument, "name is required")
}
created := idp.Copy()
created.ID = newIDPID
created.AccountID = accountID
return created, nil
},
UpdateIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if existingIDP == nil || idpID != existingIDP.ID {
return nil, status.Errorf(status.NotFound, "identity provider not found")
}
updated := idp.Copy()
updated.ID = idpID
updated.AccountID = accountID
return updated, nil
},
DeleteIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) error {
if accountID != testAccountID {
return status.Errorf(status.NotFound, "account not found")
}
if existingIDP == nil || idpID != existingIDP.ID {
return status.Errorf(status.NotFound, "identity provider not found")
}
return nil
},
},
}
}
func TestGetAllIdentityProviders(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
}
tt := []struct {
name string
expectedStatus int
expectedCount int
}{
{
name: "Get All Identity Providers",
expectedStatus: http.StatusOK,
expectedCount: 1,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/identity-providers", nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idps []api.IdentityProvider
err = json.Unmarshal(content, &idps)
require.NoError(t, err)
assert.Len(t, idps, tc.expectedCount)
})
}
}
func TestGetIdentityProvider(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
}
tt := []struct {
name string
idpID string
expectedStatus int
expectedBody bool
}{
{
name: "Get Existing Identity Provider",
idpID: existingIDPID,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Get Non-Existing Identity Provider",
idpID: "non-existing-id",
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
if tc.expectedBody {
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idp api.IdentityProvider
err = json.Unmarshal(content, &idp)
require.NoError(t, err)
assert.Equal(t, existingIDPID, *idp.Id)
assert.Equal(t, existingIDP.Name, idp.Name)
}
})
}
}
func TestCreateIdentityProvider(t *testing.T) {
tt := []struct {
name string
requestBody string
expectedStatus int
expectedBody bool
}{
{
name: "Create Identity Provider",
requestBody: `{
"name": "New IDP",
"type": "oidc",
"issuer": "https://new-issuer.example.com",
"client_id": "new-client-id",
"client_secret": "new-client-secret"
}`,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Create Identity Provider with Invalid JSON",
requestBody: `{invalid json`,
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
}
h := initIDPTestData(nil)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/identity-providers", bytes.NewBufferString(tc.requestBody))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
if tc.expectedBody {
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idp api.IdentityProvider
err = json.Unmarshal(content, &idp)
require.NoError(t, err)
assert.Equal(t, newIDPID, *idp.Id)
assert.Equal(t, "New IDP", idp.Name)
assert.Equal(t, api.IdentityProviderTypeOidc, idp.Type)
}
})
}
}
func TestUpdateIdentityProvider(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
ClientSecret: "client-secret",
}
tt := []struct {
name string
idpID string
requestBody string
expectedStatus int
expectedBody bool
}{
{
name: "Update Existing Identity Provider",
idpID: existingIDPID,
requestBody: `{
"name": "Updated IDP",
"type": "oidc",
"issuer": "https://updated-issuer.example.com",
"client_id": "updated-client-id"
}`,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Update Non-Existing Identity Provider",
idpID: "non-existing-id",
requestBody: `{
"name": "Updated IDP",
"type": "oidc",
"issuer": "https://updated-issuer.example.com",
"client_id": "updated-client-id"
}`,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "Update Identity Provider with Invalid JSON",
idpID: existingIDPID,
requestBody: `{invalid json`,
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), bytes.NewBufferString(tc.requestBody))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
if tc.expectedBody {
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idp api.IdentityProvider
err = json.Unmarshal(content, &idp)
require.NoError(t, err)
assert.Equal(t, existingIDPID, *idp.Id)
assert.Equal(t, "Updated IDP", idp.Name)
}
})
}
}
func TestDeleteIdentityProvider(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
}
tt := []struct {
name string
idpID string
expectedStatus int
}{
{
name: "Delete Existing Identity Provider",
idpID: existingIDPID,
expectedStatus: http.StatusOK,
},
{
name: "Delete Non-Existing Identity Provider",
idpID: "non-existing-id",
expectedStatus: http.StatusNotFound,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
})
}
}
func TestToAPIResponse(t *testing.T) {
idp := &types.IdentityProvider{
ID: "test-id",
Name: "Test IDP",
Type: types.IdentityProviderTypeGoogle,
Issuer: "https://accounts.google.com",
ClientID: "client-id",
ClientSecret: "should-not-be-returned",
}
response := toAPIResponse(idp)
assert.Equal(t, "test-id", *response.Id)
assert.Equal(t, "Test IDP", response.Name)
assert.Equal(t, api.IdentityProviderTypeGoogle, response.Type)
assert.Equal(t, "https://accounts.google.com", response.Issuer)
assert.Equal(t, "client-id", response.ClientId)
// Note: ClientSecret is not included in response type by design
}
func TestFromAPIRequest(t *testing.T) {
req := &api.IdentityProviderRequest{
Name: "New IDP",
Type: api.IdentityProviderTypeOkta,
Issuer: "https://dev-123456.okta.com",
ClientId: "okta-client-id",
ClientSecret: "okta-client-secret",
}
idp := fromAPIRequest(req)
assert.Equal(t, "New IDP", idp.Name)
assert.Equal(t, types.IdentityProviderTypeOkta, idp.Type)
assert.Equal(t, "https://dev-123456.okta.com", idp.Issuer)
assert.Equal(t, "okta-client-id", idp.ClientID)
assert.Equal(t, "okta-client-secret", idp.ClientSecret)
}

View File

@@ -0,0 +1,67 @@
package instance
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// handler handles the instance setup HTTP endpoints
type handler struct {
instanceManager nbinstance.Manager
}
// AddEndpoints registers the instance setup endpoints.
// These endpoints bypass authentication for initial setup.
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
// This endpoint is unauthenticated.
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
setupRequired, err := h.instanceManager.IsSetupRequired(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to check setup status: %v", err)
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
SetupRequired: setupRequired,
})
}
// setup creates the initial admin user for the instance.
// This endpoint is unauthenticated but only works when setup is required.
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
var req api.SetupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
UserId: userData.ID,
Email: userData.Email,
})
}

View File

@@ -0,0 +1,281 @@
package instance
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/mail"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
// mockInstanceManager implements instance.Manager for testing
type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
if m.isSetupRequiredFn != nil {
return m.isSetupRequiredFn(ctx)
}
return m.isSetupRequired, nil
}
func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createOwnerUserFn != nil {
return m.createOwnerUserFn(ctx, email, password, name)
}
// Default mock includes validation like the real manager
if !m.isSetupRequired {
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
}
if email == "" {
return nil, status.Errorf(status.InvalidArgument, "email is required")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, status.Errorf(status.InvalidArgument, "invalid email format")
}
if name == "" {
return nil, status.Errorf(status.InvalidArgument, "name is required")
}
if password == "" {
return nil, status.Errorf(status.InvalidArgument, "password is required")
}
if len(password) < 8 {
return nil, status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
}
return &idp.UserData{
ID: "test-user-id",
Email: email,
Name: name,
}, nil
}
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
router := mux.NewRouter()
AddEndpoints(manager, router)
return router
}
func TestGetInstanceStatus_SetupRequired(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceStatus
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.True(t, response.SetupRequired)
}
func TestGetInstanceStatus_SetupNotRequired(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: false}
router := setupTestRouter(manager)
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceStatus
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.False(t, response.SetupRequired)
}
func TestGetInstanceStatus_Error(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequiredFn: func(ctx context.Context) (bool, error) {
return false, errors.New("database error")
},
}
router := setupTestRouter(manager)
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_Success(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
assert.Equal(t, "admin@example.com", email)
assert.Equal(t, "securepassword123", password)
assert.Equal(t, "Admin User", name)
return &idp.UserData{
ID: "created-user-id",
Email: email,
Name: name,
}, nil
},
}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, "created-user-id", response.UserId)
assert.Equal(t, "admin@example.com", response.Email)
}
func TestSetup_AlreadyCompleted(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: false}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusPreconditionFailed, rec.Code)
}
func TestSetup_MissingEmail(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"password": "securepassword123"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_InvalidEmail(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"email": "not-an-email", "password": "securepassword123", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
// Note: Invalid email format uses mail.ParseAddress which is treated differently
// and returns 400 Bad Request instead of 422 Unprocessable Entity
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_MissingPassword(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_PasswordTooShort(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "short", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_InvalidJSON(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{invalid json}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSetup_CreateUserError(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, errors.New("user creation failed")
},
}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_ManagerError(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, status.Errorf(status.Internal, "database error")
},
}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}

View File

@@ -66,7 +66,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
}
srvUser := types.NewRegularUser(serviceUser)
srvUser := types.NewRegularUser(serviceUser, "", "")
srvUser.IsServiceUser = true
account := &types.Account{
@@ -75,7 +75,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
Peers: peersMap,
Users: map[string]*types.User{
adminUser: types.NewAdminUser(adminUser),
regularUser: types.NewRegularUser(regularUser),
regularUser: types.NewRegularUser(regularUser, "", ""),
serviceUser: srvUser,
},
Groups: map[string]*types.Group{

View File

@@ -326,6 +326,16 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
isCurrent := user.ID == currenUserID
var password *string
if user.Password != "" {
password = &user.Password
}
var idpID *string
if user.IdPID != "" {
idpID = &user.IdPID
}
return &api.User{
Id: user.ID,
Name: user.Name,
@@ -339,6 +349,8 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
LastLogin: &user.LastLogin,
Issued: &user.Issued,
PendingApproval: user.PendingApproval,
Password: password,
IdpId: idpID,
}
}

View File

@@ -134,6 +134,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
userAuth.IsChild = ok
}
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
// Available as userAuth.Email
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {

View File

@@ -94,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}