Compare commits

...

4 Commits

Author SHA1 Message Date
Ashley Mensah
9291e3134b Add API for managing external identity provider connectors through
Zitadel. Supports OIDC, LDAP, and SAML connector types.

New endpoints:
- GET/DELETE /api/connectors - list and delete connectors
- POST /api/connectors/oidc - add OIDC connector
- POST /api/connectors/ldap - add LDAP connector
- POST /api/connectors/saml - add SAML connector
- POST /api/connectors/{id}/activate - activate connector
- POST /api/connectors/{id}/deactivate - deactivate connector
2025-12-19 19:48:23 +01:00
Ashley Mensah
eb578146e4 refactor(idp): make NetBird single source of truth for authorization
Remove duplicate authorization data from Zitadel IdP. NetBird now stores
all authorization data (account membership, invite status, roles) locally,
while Zitadel only stores identity information (email, name, credentials).

Changes:
- Add PendingInvite field to User struct to track invite status locally
- Simplify IdP Manager interface: remove metadata methods, add GetAllUsers
- Update cache warming to match IdP users against NetBird DB
- Remove addAccountIDToIDPAppMeta and all wt_* metadata writes
- Delete legacy IdP managers (Auth0, Azure, Keycloak, Okta, Google
  Workspace, JumpCloud, Authentik, PocketId) - only Zitadel supported
2025-12-19 17:58:49 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
48 changed files with 102060 additions and 5804 deletions

View File

@@ -56,6 +56,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -109,6 +110,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -327,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -522,6 +530,18 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
// mu protects cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -0,0 +1,352 @@
# Zitadel Integration Plan
This plan is split into stages. Each stage is self-contained and can be implemented and tested independently.
---
## Stage 1: Infrastructure Setup [COMPLETED]
**Goal**: Add Zitadel to docker-compose with first-boot initialization (single container, SQLite).
### Design Principles
- **KISS**: Single Zitadel container with SQLite (no separate PostgreSQL or init containers)
- **DRY**: Leverage Zitadel's built-in first-instance configuration via environment variables
- **Clean**: Credentials generated in configure.sh and printed directly
### Files Created
#### `infrastructure_files/Caddyfile.tmpl`
Reverse proxy configuration routing:
- `/api/*`, `/management.*` → Management server
- `/relay*` → Relay
- `/signalexchange.*`, `/signal*` → Signal
- `/oauth/*`, `/oidc/*`, `/.well-known/*`, `/ui/*` → Zitadel
- `/*` → Dashboard
### Files Modified
#### `infrastructure_files/docker-compose.yml.tmpl`
Added single Zitadel service using SQLite:
```yaml
zitadel:
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
environment:
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
# ... service account config via env vars
```
#### `infrastructure_files/management.json.tmpl`
Updated IdpManagerConfig to default to Zitadel.
#### `infrastructure_files/base.setup.env`
Added Zitadel-specific environment variables:
- `ZITADEL_TAG` (default: v4.7.6)
- `ZITADEL_MASTERKEY` (auto-generated)
- `ZITADEL_ADMIN_USERNAME`, `ZITADEL_ADMIN_PASSWORD` (auto-generated)
- `ZITADEL_EXTERNALSECURE`, `ZITADEL_EXTERNALPORT`, `ZITADEL_TLS_MODE`
#### `infrastructure_files/configure.sh`
- Generates Zitadel masterkey if not set
- Generates admin credentials if not set
- Prints credentials to stdout during configuration
### Deliverable
Running `./configure.sh && cd artifacts && docker compose up -d` starts full stack with Zitadel. Admin credentials printed during configuration.
---
## Stage 2: Simplify IdP Integration
**Goal**: Make NetBird the single source of truth for user authorization data. Zitadel handles authentication only.
### Design Principles
- **Single source of truth**: NetBird DB stores all authorization data (roles, account membership, invite status)
- **Clean separation**: Zitadel stores identity only (email, name, password)
- **No duplicate data**: Remove `wt_account_id` and `wt_pending_invite` from IdP metadata
### Files to Modify
#### `management/server/types/user.go`
Add `PendingInvite` field to User struct:
```go
type User struct {
// ... existing fields ...
PendingInvite bool // NEW: tracks if user has accepted invite
}
```
Update `ToUserInfo()` to use local field instead of IdP metadata:
```go
// Before
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
// After
if u.PendingInvite {
userStatus = UserStatusInvited
}
```
#### `management/server/idp/idp.go`
Simplify `Manager` interface:
```go
type Manager interface {
// Simplified - no more accountID or invitedByEmail params
CreateUser(ctx context.Context, email, name string) (*UserData, error)
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
InviteUserByID(ctx context.Context, userID string) error
DeleteUser(ctx context.Context, userID string) error
}
```
Remove:
- `UpdateUserAppMetadata()` - no longer needed
- `GetAccount()` - account membership now in NetBird DB
- `GetAllAccounts()` - account membership now in NetBird DB
- `AppMetadata` struct fields (`WTAccountID`, `WTPendingInvite`, `WTInvitedBy`)
#### `management/server/idp/zitadel.go`
- Update `CreateUser()` to not set metadata
- Remove `UpdateUserAppMetadata()` implementation
- Simplify `GetUserDataByID()` to not expect metadata
#### `management/server/user.go`
Update user creation flow:
1. `inviteNewUser()`: Set `PendingInvite = true` when creating user
2. On first login: Set `PendingInvite = false`
### Data Flow (After)
```
Invite User:
1. Admin invites user@example.com via NetBird UI
2. NetBird calls Zitadel: CreateUser(email, name) // No metadata
3. Zitadel creates user, sends invite email
4. NetBird creates User in its DB:
- Id = Zitadel user ID
- AccountID = current account
- Role = invited role
- PendingInvite = true
First Login:
1. User clicks invite, sets password in Zitadel
2. User logs into NetBird Dashboard
3. NetBird updates User in its DB:
- PendingInvite = false
- LastLogin = now
```
### Deliverable
- NetBird is single source of truth for authorization
- Zitadel stores identity only (email, name, credentials)
- No metadata sync issues between systems
---
## Stage 3: Remove Legacy IdP Managers
**Goal**: Remove all non-Zitadel IdP implementations.
### Files to Delete
- `management/server/idp/auth0.go`
- `management/server/idp/auth0_test.go`
- `management/server/idp/azure.go`
- `management/server/idp/azure_test.go`
- `management/server/idp/keycloak.go`
- `management/server/idp/keycloak_test.go`
- `management/server/idp/okta.go`
- `management/server/idp/okta_test.go`
- `management/server/idp/google_workspace.go`
- `management/server/idp/google_workspace_test.go`
- `management/server/idp/jumpcloud.go`
- `management/server/idp/jumpcloud_test.go`
- `management/server/idp/authentik.go`
- `management/server/idp/authentik_test.go`
- `management/server/idp/pocketid.go`
- `management/server/idp/pocketid_test.go`
### Files to Modify
#### `management/server/idp/idp.go`
Simplify `NewManager()` to only support Zitadel:
```go
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
switch strings.ToLower(config.ManagerType) {
case "none", "":
return nil, nil
case "zitadel":
return NewZitadelManager(...)
default:
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
}
}
```
#### `management/server/idp/idp.go`
Remove unused config structs:
- `Auth0ClientConfig`
- `AzureClientConfig`
- `KeycloakClientConfig`
- etc.
### Deliverable
Only Zitadel IdP manager remains. Build passes. Tests pass.
---
## Stage 4: External IdP Connector API
**Goal**: API for users to add external IdPs (Okta, Google, LDAP) as Zitadel connectors.
### New Files
#### `management/server/idp/connectors.go`
Wrapper for Zitadel IdP connector API:
```go
type Connector struct {
ID string
Name string
Type string // "oidc", "ldap", "saml"
Issuer string // for OIDC
ClientID string // for OIDC
}
type ConnectorManager interface {
AddOIDCConnector(ctx, name, issuer, clientID, clientSecret string, scopes []string) (*Connector, error)
AddLDAPConnector(ctx, name, host string, port int, baseDN, bindDN, bindPassword string) (*Connector, error)
ListConnectors(ctx) ([]Connector, error)
DeleteConnector(ctx, connectorID string) error
}
```
#### `management/server/http/handlers/idp_connectors_handler.go`
REST API handlers:
```
POST /api/idp/connectors - Add connector
GET /api/idp/connectors - List connectors
DELETE /api/idp/connectors/{id} - Delete connector
```
### Zitadel API Endpoints Used
- `POST /management/v1/idps/generic_oidc` - Add OIDC connector
- `POST /management/v1/idps/ldap` - Add LDAP connector
- `GET /management/v1/idps` - List all IdPs
- `DELETE /management/v1/idps/{id}` - Remove IdP
### Deliverable
Admin users can add/remove external IdP connectors via NetBird API.
---
## Stage 5: User Role Permissions
**Goal**: Enforce admin/user permissions throughout the application.
### Files to Modify
#### `management/server/types/user.go`
Simplify roles to:
```go
const (
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
)
```
Keep `owner` as alias for `admin` for backwards compatibility.
#### Permission checks (various files)
Update permission checks to use simplified role model:
- Admin: Full access to all resources
- User: View/manage own peers only, no user/policy management
### Permission Matrix
| Resource | Admin | User |
|----------|-------|------|
| All peers | read/write | - |
| Own peers | read/write | read/write |
| Users | read/write | - |
| Groups | read/write | read (own) |
| Policies | read/write | - |
| IdP Connectors | read/write | - |
| Account settings | read/write | - |
### Deliverable
Role-based access control enforced. User role has limited dashboard access.
---
## Architecture Diagram
```
┌─────────────────────────────────────────────────────────────────┐
│ NETBIRD DEPLOYMENT │
│ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ CADDY (Reverse Proxy) │ │
│ │ :80/:443 → routes to appropriate service │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
│ │ Dashboard│ │Management│ │ Signal │ │ Zitadel │ │
│ │ :80 │ │ :80 │ │ :80 │ │ :8080 (SQLite) │ │
│ └──────────┘ └────┬─────┘ └──────────┘ └──────────────────┘ │
│ │ │ │
│ │ OIDC + Mgmt API │ │
│ └──────────────────────────┘ │
│ │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ Relay │ Coturn (TURN) │ │
│ └───────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
```
---
## Implementation Order
1. **Stage 1** - Infrastructure (can be tested standalone)
2. **Stage 2** - Role management (requires Stage 1)
3. **Stage 3** - Remove legacy IdPs (can be done in parallel with Stage 2)
4. **Stage 4** - Connector API (requires Stage 1)
5. **Stage 5** - Permissions (requires Stage 2)
Stages 3 and 4 can be worked on in parallel after Stage 1 is complete.
---
## Security Considerations
### First-Boot Credentials
- Generated with `openssl rand -base64 32` in configure.sh
- Printed to stdout during configuration (save them!)
- Marked as "must change on first login" in Zitadel via `PASSWORDCHANGEREQUIRED=true`
### Zitadel Masterkey
- Auto-generated if not provided (32 bytes)
- Stored in environment variable (docker secret in production)
- Used for Zitadel's internal encryption
### Service Account
- Machine user created via `ZITADEL_FIRSTINSTANCE_ORG_MACHINE_*` env vars
- PAT (Personal Access Token) generated with 1-year expiration
- Used by NetBird management to call Zitadel API
---
## Migration Notes
Existing deployments using Auth0/Azure/Keycloak will need to:
1. Deploy Zitadel alongside existing IdP
2. Add existing IdP as Zitadel connector
3. Migrate users (or let them re-authenticate via connector)
4. Switch NetBird config to use Zitadel
5. Remove old IdP
A migration guide should be provided as documentation.

View File

@@ -0,0 +1,83 @@
{
servers :80,:443 {
protocols h1 h2c h2 h3
}
}
(security_headers) {
header * {
# HSTS - use 1 hour for testing, increase to 63072000 (2 years) in production
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
# Prevent MIME type sniffing
X-Content-Type-Options "nosniff"
# Clickjacking protection
X-Frame-Options "SAMEORIGIN"
# XSS protection
X-XSS-Protection "1; mode=block"
# Remove server header
-Server
# Referrer policy
Referrer-Policy strict-origin-when-cross-origin
}
}
:${NETBIRD_CADDY_PORT}${CADDY_SECURE_DOMAIN} {
import security_headers
# Relay
reverse_proxy /relay* relay:${NETBIRD_RELAY_INTERNAL_PORT}
# Signal - WebSocket proxy
reverse_proxy /ws-proxy/signal* signal:80
# Signal - gRPC
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management - REST API
reverse_proxy /api/* management:80
# Management - WebSocket proxy
reverse_proxy /ws-proxy/management* management:80
# Management - gRPC
reverse_proxy /management.ManagementService/* h2c://management:80
# Zitadel - Admin API
reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080
reverse_proxy /admin/v1/* h2c://zitadel:8080
# Zitadel - Auth API
reverse_proxy /zitadel.auth.v1.AuthService/* h2c://zitadel:8080
reverse_proxy /auth/v1/* h2c://zitadel:8080
# Zitadel - Management API
reverse_proxy /zitadel.management.v1.ManagementService/* h2c://zitadel:8080
reverse_proxy /management/v1/* h2c://zitadel:8080
# Zitadel - System API
reverse_proxy /zitadel.system.v1.SystemService/* h2c://zitadel:8080
reverse_proxy /system/v1/* h2c://zitadel:8080
# Zitadel - User API v2
reverse_proxy /zitadel.user.v2.UserService/* h2c://zitadel:8080
# Zitadel - Assets
reverse_proxy /assets/v1/* h2c://zitadel:8080
# Zitadel - UI (login, console, etc.)
reverse_proxy /ui/* h2c://zitadel:8080
# Zitadel - OIDC endpoints
reverse_proxy /oidc/v1/* h2c://zitadel:8080
reverse_proxy /oauth/v2/* h2c://zitadel:8080
reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080
# Zitadel - SAML
reverse_proxy /saml/v2/* h2c://zitadel:8080
# Zitadel - Other
reverse_proxy /openapi/* h2c://zitadel:8080
reverse_proxy /debug/* h2c://zitadel:8080
reverse_proxy /device/* h2c://zitadel:8080
reverse_proxy /device h2c://zitadel:8080
# Dashboard - catch-all for frontend
reverse_proxy /* dashboard:80
}

View File

@@ -2,29 +2,20 @@
# Management API
# Management API port
NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073}
# Management API endpoint address, used by the Dashboard
NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
# Management Certificate file path. These are generated by the Dashboard container
NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem"
# Management Certificate key file path.
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem"
# Management API endpoint address, used by the Dashboard (Caddy handles TLS)
NETBIRD_MGMT_API_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
# Signal
NETBIRD_SIGNAL_PROTOCOL="http"
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
NETBIRD_SIGNAL_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-443}
# Relay
NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN}
NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080}
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-rel://$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT}
# Relay (internal port for Caddy reverse proxy)
NETBIRD_RELAY_INTERNAL_PORT=${NETBIRD_RELAY_INTERNAL_PORT:-80}
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-${NETBIRD_RELAY_PROTO:-rels}://$NETBIRD_DOMAIN:${NETBIRD_RELAY_PORT:-443}}
# Relay auth secret
NETBIRD_RELAY_AUTH_SECRET=
@@ -141,3 +132,57 @@ export NETBIRD_RELAY_ENDPOINT
export NETBIRD_RELAY_AUTH_SECRET
export NETBIRD_RELAY_TAG
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
# Zitadel IdP Configuration
ZITADEL_TAG=${ZITADEL_TAG:-"v4.7.6"}
# Zitadel masterkey (32 bytes, auto-generated if not set)
ZITADEL_MASTERKEY=
# Zitadel admin credentials (auto-generated if not set)
ZITADEL_ADMIN_USERNAME=
ZITADEL_ADMIN_PASSWORD=
# Zitadel external configuration
ZITADEL_EXTERNALSECURE=${ZITADEL_EXTERNALSECURE:-true}
ZITADEL_EXTERNALPORT=${ZITADEL_EXTERNALPORT:-443}
ZITADEL_TLS_MODE=${ZITADEL_TLS_MODE:-external}
# Zitadel PAT expiration (1 year from startup)
ZITADEL_PAT_EXPIRATION=
# Zitadel management endpoint
ZITADEL_MANAGEMENT_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN/management/v1
# HTTP protocol (http or https)
NETBIRD_HTTP_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
# Caddy configuration
NETBIRD_CADDY_PORT=${NETBIRD_CADDY_PORT:-80}
CADDY_SECURE_DOMAIN=
# Zitadel OIDC endpoints
NETBIRD_AUTH_AUTHORITY=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
NETBIRD_AUTH_TOKEN_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration
NETBIRD_AUTH_JWT_CERTS=${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json
NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize
NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization
NETBIRD_AUTH_USER_ID_CLAIM=${NETBIRD_AUTH_USER_ID_CLAIM:-sub}
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-"openid profile email offline_access"}
# Zitadel exports
export ZITADEL_TAG
export ZITADEL_MASTERKEY
export ZITADEL_ADMIN_USERNAME
export ZITADEL_ADMIN_PASSWORD
export ZITADEL_EXTERNALSECURE
export ZITADEL_EXTERNALPORT
export ZITADEL_TLS_MODE
export ZITADEL_PAT_EXPIRATION
export ZITADEL_MANAGEMENT_ENDPOINT
export NETBIRD_HTTP_PROTOCOL
export NETBIRD_CADDY_PORT
export CADDY_SECURE_DOMAIN
export NETBIRD_AUTH_AUTHORITY
export NETBIRD_AUTH_TOKEN_ENDPOINT
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT
export NETBIRD_AUTH_JWT_CERTS
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT
export NETBIRD_AUTH_USER_ID_CLAIM
export NETBIRD_AUTH_SUPPORTED_SCOPES
export NETBIRD_RELAY_INTERNAL_PORT

View File

@@ -1,261 +1,137 @@
#!/bin/bash
set -e
if ! which curl >/dev/null 2>&1; then
echo "This script uses curl fetch OpenID configuration from IDP."
echo "Please install curl and re-run the script https://curl.se/"
echo ""
exit 1
fi
if ! which jq >/dev/null 2>&1; then
echo "This script uses jq to load OpenID configuration from IDP."
echo "Please install jq and re-run the script https://stedolan.github.io/jq/"
echo ""
exit 1
fi
# Check required dependencies
for cmd in curl jq envsubst openssl; do
if ! which $cmd >/dev/null 2>&1; then
echo "This script requires $cmd. Please install it and re-run."
exit 1
fi
done
# Source configuration
source setup.env
source base.setup.env
if ! which envsubst >/dev/null 2>&1; then
echo "envsubst is needed to run this script"
if [[ $(uname) == "Darwin" ]]; then
echo "you can install it with homebrew (https://brew.sh):"
echo "brew install gettext"
else
if which apt-get >/dev/null 2>&1; then
echo "you can install it by running"
echo "apt-get update && apt-get install gettext-base"
else
echo "you can install it by installing the package gettext with your package manager"
fi
fi
# Validate required variables
if [[ -z "$NETBIRD_DOMAIN" ]]; then
echo "NETBIRD_DOMAIN is not set, please update your setup.env file"
exit 1
fi
if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then
echo NETBIRD_DOMAIN is not set, please update your setup.env file
echo If you are migrating from old versions, you might need to update your variables prefixes from
echo WIRETRUSTEE_.. TO NETBIRD_
# Check database configuration if using external database
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" && -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
exit 1
fi
# Check if PostgreSQL is set as the store engine
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then
# Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set
if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
echo "Please add the following line to your setup.env file:"
echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<PG_HOST> user=<PG_USER> password=<PG_PASSWORD> dbname=<PG_DB_NAME> port=<PG_PORT>"'
exit 1
fi
export NETBIRD_STORE_ENGINE_POSTGRES_DSN
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" && -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
exit 1
fi
# Check if MySQL is set as the store engine
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" ]]; then
# Exit if 'NETBIRD_STORE_ENGINE_MYSQL_DSN' is not set
if [[ -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
echo "Please add the following line to your setup.env file:"
echo 'NETBIRD_STORE_ENGINE_MYSQL_DSN="<username>:<password>@tcp(127.0.0.1:3306)/<database>"'
exit 1
fi
export NETBIRD_STORE_ENGINE_MYSQL_DSN
fi
# local development or tests
# Configure for local development vs production
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
export NETBIRD_MGMT_API_ENDPOINT=http://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi
# if not provided, we generate a turn password
if [[ "x-$TURN_PASSWORD" == "x-" ]]; then
export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
fi
TURN_EXTERNAL_IP_CONFIG="#"
if [[ "x-$NETBIRD_TURN_EXTERNAL_IP" == "x-" ]]; then
echo "discovering server's public IP"
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
if [[ "x-$IP" != "x-" ]]; then
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
else
echo "unable to discover server's public IP"
fi
export NETBIRD_MGMT_API_ENDPOINT="http://$NETBIRD_DOMAIN"
export NETBIRD_HTTP_PROTOCOL="http"
export ZITADEL_EXTERNALSECURE="false"
export ZITADEL_EXTERNALPORT="80"
export ZITADEL_TLS_MODE="disabled"
export NETBIRD_RELAY_PROTO="rel"
else
echo "${NETBIRD_TURN_EXTERNAL_IP}"| egrep '([0-9]{1,3}\.){3}[0-9]{1,3}$' > /dev/null
if [[ $? -eq 0 ]]; then
echo "using provided server's public IP"
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
else
echo "provided NETBIRD_TURN_EXTERNAL_IP $NETBIRD_TURN_EXTERNAL_IP is invalid, please correct it and try again"
exit 1
fi
export NETBIRD_HTTP_PROTOCOL="https"
export ZITADEL_EXTERNALSECURE="true"
export ZITADEL_EXTERNALPORT="443"
export ZITADEL_TLS_MODE="external"
export NETBIRD_RELAY_PROTO="rels"
export CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:443"
fi
# Auto-generate secrets if not provided
[[ -z "$TURN_PASSWORD" ]] && export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
[[ -z "$NETBIRD_RELAY_AUTH_SECRET" ]] && export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
[[ -z "$ZITADEL_MASTERKEY" ]] && export ZITADEL_MASTERKEY=$(openssl rand -base64 32 | head -c 32)
# Generate Zitadel admin credentials if not provided
if [[ -z "$ZITADEL_ADMIN_USERNAME" ]]; then
export ZITADEL_ADMIN_USERNAME="admin@${NETBIRD_DOMAIN}"
fi
if [[ -z "$ZITADEL_ADMIN_PASSWORD" ]]; then
export ZITADEL_ADMIN_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')!"
fi
# Set Zitadel PAT expiration (1 year from now)
if [[ "$OSTYPE" == "darwin"* ]]; then
export ZITADEL_PAT_EXPIRATION=$(date -u -v+1y "+%Y-%m-%dT%H:%M:%SZ")
else
export ZITADEL_PAT_EXPIRATION=$(date -u -d "+1 year" "+%Y-%m-%dT%H:%M:%SZ")
fi
# Discover external IP for TURN
TURN_EXTERNAL_IP_CONFIG="#"
if [[ -z "$NETBIRD_TURN_EXTERNAL_IP" ]]; then
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip' 2>/dev/null || echo "")
[[ -n "$IP" ]] && TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
elif echo "$NETBIRD_TURN_EXTERNAL_IP" | grep -qE '^([0-9]{1,3}\.){3}[0-9]{1,3}$'; then
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
fi
export TURN_EXTERNAL_IP_CONFIG
# if not provided, we generate a relay auth secret
if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then
export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
fi
artifacts_path="./artifacts"
mkdir -p $artifacts_path
# Configure endpoints
export NETBIRD_AUTH_AUTHORITY="${NETBIRD_HTTP_PROTOCOL}://${NETBIRD_DOMAIN}"
export NETBIRD_AUTH_TOKEN_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token"
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration"
export NETBIRD_AUTH_JWT_CERTS="${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json"
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize"
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization"
export ZITADEL_MANAGEMENT_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/management/v1"
export NETBIRD_RELAY_ENDPOINT="${NETBIRD_RELAY_PROTO}://${NETBIRD_DOMAIN}:${ZITADEL_EXTERNALPORT}"
# Volume names (with backwards compatibility)
MGMT_VOLUMENAME="${VOLUME_PREFIX}${MGMT_VOLUMESUFFIX}"
SIGNAL_VOLUMENAME="${VOLUME_PREFIX}${SIGNAL_VOLUMESUFFIX}"
LETSENCRYPT_VOLUMENAME="${VOLUME_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
# if volume with wiretrustee- prefix already exists, use it, else create new with netbird-
OLD_PREFIX='wiretrustee-'
if docker volume ls | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"; then
MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
fi
if docker volume ls | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"; then
SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
fi
if docker volume ls | grep -q "${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"; then
LETSENCRYPT_VOLUMENAME="${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}" && MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}" && SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
export MGMT_VOLUMENAME SIGNAL_VOLUMENAME
# Preserve existing encryption key
if test -f 'management.json'; then
encKey=$(jq -r ".DataStoreEncryptionKey" management.json 2>/dev/null || echo "null")
[[ "$encKey" != "null" && -n "$encKey" ]] && export NETBIRD_DATASTORE_ENC_KEY="$encKey"
fi
export MGMT_VOLUMENAME
export SIGNAL_VOLUMENAME
export LETSENCRYPT_VOLUMENAME
#backwards compatibility after migrating to generic OIDC with Auth0
if [[ -z "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" ]]; then
if [[ -z "${NETBIRD_AUTH0_DOMAIN}" ]]; then
# not a backward compatible state
echo "NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT property must be set in the setup.env file"
exit 1
fi
echo "It seems like you provided an old setup.env file."
echo "Since the release of v0.8.10, we introduced a new set of properties."
echo "The script is backward compatible and will continue automatically."
echo "In the future versions it will be deprecated. Please refer to the documentation to learn about the changes http://netbird.io/docs/getting-started/self-hosting"
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://${NETBIRD_AUTH0_DOMAIN}/.well-known/openid-configuration"
export NETBIRD_USE_AUTH0="true"
export NETBIRD_AUTH_AUDIENCE=${NETBIRD_AUTH0_AUDIENCE}
export NETBIRD_AUTH_CLIENT_ID=${NETBIRD_AUTH0_CLIENT_ID}
fi
echo "loading OpenID configuration from ${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT} to the openid-configuration.json file"
curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o ${artifacts_path}/openid-configuration.json
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' ${artifacts_path}/openid-configuration.json)
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' ${artifacts_path}/openid-configuration.json)
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' ${artifacts_path}/openid-configuration.json)
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' ${artifacts_path}/openid-configuration.json)
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=$(jq -r '.authorization_endpoint' ${artifacts_path}/openid-configuration.json)
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
# user enabled Device Authorization Grant feature
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
fi
if [ "$NETBIRD_TOKEN_SOURCE" = "idToken" ]; then
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN=true
fi
# Check if letsencrypt was disabled
if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443"
export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT"
export NETBIRD_RELAY_ENDPOINT="rels://$NETBIRD_DOMAIN:$NETBIRD_RELAY_PORT/relay"
echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore"
echo " and a reverse-proxy with Https needs to be placed in front of netbird!"
echo "The following forwards have to be setup:"
echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80"
echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT"
echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT"
echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80"
echo "- $NETBIRD_RELAY_ENDPOINT/ -http-> relay:33080"
echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script."
echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!"
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
echo ""
unset NETBIRD_LETSENCRYPT_DOMAIN
unset NETBIRD_MGMT_API_CERT_FILE
unset NETBIRD_MGMT_API_CERT_KEY_FILE
fi
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
export NETBIRD_SIGNAL_PROTOCOL="https"
fi
# Check if management identity provider is set
if [ -n "$NETBIRD_MGMT_IDP" ]; then
EXTRA_CONFIG={}
# extract extra config from all env prefixed with NETBIRD_IDP_MGMT_EXTRA_
for var in ${!NETBIRD_IDP_MGMT_EXTRA_*}; do
# convert key snake case to camel case
key=$(
echo "${var#NETBIRD_IDP_MGMT_EXTRA_}" | awk -F "_" \
'{for (i=1; i<=NF; i++) {output=output substr($i,1,1) tolower(substr($i,2))} print output}'
)
value="${!var}"
echo "$var"
EXTRA_CONFIG=$(jq --arg k "$key" --arg v "$value" '.[$k] = $v' <<<"$EXTRA_CONFIG")
done
export NETBIRD_MGMT_IDP
export NETBIRD_IDP_MGMT_CLIENT_ID
export NETBIRD_IDP_MGMT_CLIENT_SECRET
export NETBIRD_IDP_MGMT_EXTRA_CONFIG=$EXTRA_CONFIG
else
export NETBIRD_IDP_MGMT_EXTRA_CONFIG={}
fi
IFS=',' read -r -a REDIRECT_URL_PORTS <<< "$NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS"
REDIRECT_URLS=""
for port in "${REDIRECT_URL_PORTS[@]}"; do
REDIRECT_URLS+="\"http://localhost:${port}\","
# Create artifacts directory and backup existing files
artifacts_path="./artifacts"
mkdir -p "$artifacts_path"
bkp_postfix="$(date +%s)"
for file in docker-compose.yml management.json turnserver.conf Caddyfile; do
[[ -f "${artifacts_path}/${file}" ]] && cp "${artifacts_path}/${file}" "${artifacts_path}/${file}.bkp.${bkp_postfix}"
done
export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,}
# Generate configuration files
envsubst < docker-compose.yml.tmpl > "$artifacts_path/docker-compose.yml"
envsubst < management.json.tmpl | jq . > "$artifacts_path/management.json"
envsubst < turnserver.conf.tmpl > "$artifacts_path/turnserver.conf"
envsubst < Caddyfile.tmpl > "$artifacts_path/Caddyfile"
# Remove audience for providers that do not support it
if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
export NETBIRD_DASH_AUTH_AUDIENCE=none
export NETBIRD_AUTH_PKCE_AUDIENCE=
fi
# Read the encryption key
if test -f 'management.json'; then
encKey=$(jq -r ".DataStoreEncryptionKey" management.json)
if [[ "$encKey" != "null" ]]; then
export NETBIRD_DATASTORE_ENC_KEY=$encKey
fi
fi
env | grep NETBIRD
bkp_postfix="$(date +%s)"
if test -f "${artifacts_path}/docker-compose.yml"; then
cp $artifacts_path/docker-compose.yml "${artifacts_path}/docker-compose.yml.bkp.${bkp_postfix}"
fi
if test -f "${artifacts_path}/management.json"; then
cp $artifacts_path/management.json "${artifacts_path}/management.json.bkp.${bkp_postfix}"
fi
if test -f "${artifacts_path}/turnserver.conf"; then
cp ${artifacts_path}/turnserver.conf "${artifacts_path}/turnserver.conf.bkp.${bkp_postfix}"
fi
envsubst <docker-compose.yml.tmpl >$artifacts_path/docker-compose.yml
envsubst <management.json.tmpl | jq . >$artifacts_path/management.json
envsubst <turnserver.conf.tmpl >$artifacts_path/turnserver.conf
# Print summary
echo ""
echo "=========================================="
echo " NetBird Configuration Complete"
echo "=========================================="
echo " Domain: $NETBIRD_DOMAIN"
echo " Protocol: $NETBIRD_HTTP_PROTOCOL"
echo " Zitadel: $ZITADEL_TAG (SQLite)"
echo "=========================================="
echo ""
echo " ADMIN CREDENTIALS (save these!):"
echo " Username: $ZITADEL_ADMIN_USERNAME"
echo " Password: $ZITADEL_ADMIN_PASSWORD"
echo ""
echo "=========================================="
echo ""
echo "To start NetBird:"
echo " cd $artifacts_path && docker compose up -d"
echo ""

View File

@@ -7,108 +7,137 @@ x-default: &default
max-file: '2'
services:
# Caddy reverse proxy
caddy:
<<: *default
image: caddy:2
networks: [netbird]
ports:
- '443:443'
- '443:443/udp'
- '80:80'
volumes:
- netbird-caddy-data:/data
- ./Caddyfile:/etc/caddy/Caddyfile:ro
depends_on:
zitadel:
condition: service_healthy
# UI dashboard
dashboard:
<<: *default
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
ports:
- 80:80
- 443:443
networks: [netbird]
environment:
# Endpoints
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
# OIDC
- AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE
- AUTH_AUDIENCE=$NETBIRD_AUTH_CLIENT_ID
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
- AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET
- AUTH_CLIENT_SECRET=
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
- USE_AUTH0=$NETBIRD_USE_AUTH0
- USE_AUTH0=false
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
- AUTH_REDIRECT_URI=/nb-auth
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
# SSL
- NGINX_SSL_PORT=443
# Letsencrypt
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
- LETSENCRYPT_DOMAIN=none
depends_on:
zitadel:
condition: service_healthy
# Signal
signal:
<<: *default
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
depends_on:
- dashboard
networks: [netbird]
volumes:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
ports:
- $NETBIRD_SIGNAL_PORT:80
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
"--log-file", "console"
]
command: ["--log-file", "console"]
# Relay
relay:
<<: *default
image: netbirdio/relay:$NETBIRD_RELAY_TAG
networks: [netbird]
environment:
- NB_LOG_LEVEL=info
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
# todo: change to a secure secret
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
ports:
- $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
- NB_LOG_LEVEL=info
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_INTERNAL_PORT
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
# Management
management:
<<: *default
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
depends_on:
- dashboard
networks: [netbird]
volumes:
- $MGMT_VOLUMENAME:/var/lib/netbird
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
- ./management.json:/etc/netbird/management.json
ports:
- $NETBIRD_MGMT_API_PORT:443 #API port
# # command for Let's Encrypt validation without dashboard container
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
command: [
"--port", "443",
"--port", "80",
"--log-file", "console",
"--log-level", "info",
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
]
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN",
"--idp-sign-key-refresh-enabled"
]
environment:
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
depends_on:
zitadel:
condition: service_healthy
# Coturn
coturn:
<<: *default
image: coturn/coturn:$COTURN_TAG
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
volumes:
- ./turnserver.conf:/etc/turnserver.conf:ro
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
# - ./cert.pem:/etc/coturn/certs/cert.pem:ro
network_mode: host
command:
- -c /etc/turnserver.conf
# Zitadel Identity Provider (single container with SQLite)
zitadel:
<<: *default
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
networks: [netbird]
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
environment:
- ZITADEL_LOG_LEVEL=info
- ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
- ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
- ZITADEL_TLS_ENABLED=false
- ZITADEL_EXTERNALPORT=$ZITADEL_EXTERNALPORT
- ZITADEL_EXTERNALDOMAIN=$NETBIRD_DOMAIN
# SQLite database (no separate PostgreSQL container needed)
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
# First instance: Admin user
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORDCHANGEREQUIRED=true
# First instance: Service account for NetBird management
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=netbird-service
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=NetBird Service Account
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINEKEY_TYPE=1
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZITADEL_PAT_EXPIRATION
volumes:
- netbird-zitadel-data:/data
healthcheck:
test: ["CMD", "/app/zitadel", "ready"]
interval: 10s
timeout: 5s
retries: 5
start_period: 30s
volumes:
$MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME:
$LETSENCRYPT_VOLUMENAME:
netbird-caddy-data:
netbird-zitadel-data:
networks:
netbird:

View File

@@ -45,18 +45,18 @@
"Engine": "$NETBIRD_STORE_CONFIG_ENGINE"
},
"HttpConfig": {
"Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT",
"Address": "0.0.0.0:80",
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
"AuthAudience": "$NETBIRD_AUTH_CLIENT_ID",
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
"IdpSignKeyRefreshEnabled": $NETBIRD_MGMT_IDP_SIGNKEY_REFRESH,
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
"CertFile": "",
"CertKey": "",
"IdpSignKeyRefreshEnabled": true,
"OIDCConfigEndpoint": "$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
},
"IdpManagerConfig": {
"ManagerType": "$NETBIRD_MGMT_IDP",
"ManagerType": "zitadel",
"ClientConfig": {
"Issuer": "$NETBIRD_AUTH_AUTHORITY",
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
@@ -64,40 +64,28 @@
"ClientSecret": "$NETBIRD_IDP_MGMT_CLIENT_SECRET",
"GrantType": "client_credentials"
},
"ExtraConfig": $NETBIRD_IDP_MGMT_EXTRA_CONFIG,
"Auth0ClientCredentials": null,
"AzureClientCredentials": null,
"KeycloakClientCredentials": null,
"ZitadelClientCredentials": null
},
"ExtraConfig": {
"ManagementEndpoint": "$ZITADEL_MANAGEMENT_ENDPOINT"
}
},
"DeviceAuthorizationFlow": {
"Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER",
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE",
"AuthorizationEndpoint": "",
"Domain": "$NETBIRD_AUTH0_DOMAIN",
"ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID",
"ClientSecret": "",
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
"Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE",
"UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN,
"RedirectURLs": null
}
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
"Scope": "openid"
}
},
"PKCEAuthorizationFlow": {
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID",
"ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET",
"Domain": "",
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT",
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
"Scope": "openid profile email offline_access",
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
}
}
}

View File

@@ -1,117 +1,65 @@
## example file, you can copy this file to setup.env and update its values
##
# NetBird Self-Hosted Setup Configuration
# Copy this file to setup.env and configure the required values
# Image tags
# you can force specific tags for each component; will be set to latest if empty
# -------------------------------------------
# Required: Domain Configuration
# -------------------------------------------
# Your NetBird domain (e.g., netbird.mydomain.com)
NETBIRD_DOMAIN=""
# -------------------------------------------
# Optional: Image Tags
# -------------------------------------------
# Leave empty to use 'latest' for all components
NETBIRD_DASHBOARD_TAG=""
NETBIRD_SIGNAL_TAG=""
NETBIRD_MANAGEMENT_TAG=""
COTURN_TAG=""
NETBIRD_RELAY_TAG=""
# Dashboard domain. e.g. app.mydomain.com
NETBIRD_DOMAIN=""
# Zitadel version (default: v2.64.1)
ZITADEL_TAG=""
# TURN server domain. e.g. turn.mydomain.com
# if not specified it will assume NETBIRD_DOMAIN
# -------------------------------------------
# Optional: TURN Server Configuration
# -------------------------------------------
# TURN server domain (defaults to NETBIRD_DOMAIN)
NETBIRD_TURN_DOMAIN=""
# TURN server public IP address
# required for a connection involving peers in
# the same network as the server and external peers
# usually matches the IP for the domain set in NETBIRD_TURN_DOMAIN
# Required for peers behind NAT to connect
NETBIRD_TURN_EXTERNAL_IP=""
# -------------------------------------------
# OIDC
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration
# Optional: Database Configuration
# -------------------------------------------
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
# The default setting is to transmit the audience to the IDP during authorization. However,
# if your IDP does not have this capability, you can turn this off by setting it to false.
#NETBIRD_DASH_AUTH_USE_AUDIENCE=false
NETBIRD_AUTH_AUDIENCE=""
# e.g. netbird-client
NETBIRD_AUTH_CLIENT_ID=""
# indicates the scopes that will be requested to the IDP
NETBIRD_AUTH_SUPPORTED_SCOPES=""
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
# NETBIRD_AUTH_CLIENT_SECRET=""
# if you want to use a custom claim for the user ID instead of 'sub', set it here
# NETBIRD_AUTH_USER_ID_CLAIM=""
# indicates whether to use Auth0 or not: true or false
NETBIRD_USE_AUTH0="false"
# if your IDP provider doesn't support fragmented URIs, configure custom
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
# NETBIRD_AUTH_REDIRECT_URI="/peers"
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
# Updates the preference to use id tokens instead of access token on dashboard
# Okta and Gitlab IDPs can benefit from this
# NETBIRD_TOKEN_SOURCE="idToken"
# Store engine: sqlite (default), postgres, or mysql
NETBIRD_STORE_CONFIG_ENGINE=""
# For PostgreSQL:
# NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<HOST> user=<USER> password=<PASS> dbname=<DB> port=5432"
# For MySQL:
# NETBIRD_STORE_ENGINE_MYSQL_DSN="<user>:<pass>@tcp(127.0.0.1:3306)/<db>"
# -------------------------------------------
# OIDC Device Authorization Flow
# Optional: Extra Settings
# -------------------------------------------
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
# Some IDPs requires different audience, scopes and to use id token for device authorization flow
# you can customize here:
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid"
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=false
# -------------------------------------------
# OIDC PKCE Authorization Flow
# -------------------------------------------
# Comma separated port numbers. if already in use, PKCE flow will choose an available port from the list as an alternative
# eg. 53000,54000
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS="53000"
# -------------------------------------------
# IDP Management
# -------------------------------------------
# eg. zitadel, auth0, azure, keycloak
NETBIRD_MGMT_IDP="none"
# Some IDPs requires different client id and client secret for management api
NETBIRD_IDP_MGMT_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
NETBIRD_IDP_MGMT_CLIENT_SECRET=""
# Required when setting up with Keycloak "https://<YOUR_KEYCLOAK_HOST_AND_PORT>/admin/realms/netbird"
# NETBIRD_IDP_MGMT_EXTRA_ADMIN_ENDPOINT=
# With some IDPs may be needed enabling automatic refresh of signing keys on expire
# NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=false
# NETBIRD_IDP_MGMT_EXTRA_ variables. See https://docs.netbird.io/selfhosted/identity-providers for more information about your IDP of choice.
# -------------------------------------------
# Letsencrypt
# -------------------------------------------
# Disable letsencrypt
# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead
NETBIRD_DISABLE_LETSENCRYPT=false
# e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL=""
# -------------------------------------------
# Extra settings
# -------------------------------------------
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
# Disable anonymous metrics (default: false)
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
# DNS domain for peer resolution (default: netbird.selfhosted)
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
# Disable default all-to-all policy for new accounts
# Disable default all-to-all policy for new accounts (default: false)
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
# -------------------------------------------
# Relay settings
# Advanced: Zitadel Client IDs
# -------------------------------------------
# Relay server domain. e.g. relay.mydomain.com
# if not specified it will assume NETBIRD_DOMAIN
NETBIRD_RELAY_DOMAIN=""
# Relay server connection port. If none is supplied
# it will default to 33080
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
NETBIRD_RELAY_PORT=""
# Management API connecting port. If none is supplied
# it will default to 33073
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
NETBIRD_MGMT_API_PORT=""
# Signal service connecting port. If none is supplied
# it will default to 10000
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
NETBIRD_SIGNAL_PORT=""
# These are auto-generated by Zitadel on first boot
# Only set these if migrating from an existing Zitadel setup
# NETBIRD_AUTH_CLIENT_ID=""
# NETBIRD_AUTH_CLIENT_ID_CLI=""
# NETBIRD_IDP_MGMT_CLIENT_ID=""
# NETBIRD_IDP_MGMT_CLIENT_SECRET=""

View File

@@ -587,42 +587,40 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cache
time.Sleep(delay)
}
userData, err := am.idpManager.GetAllAccounts(ctx)
// Get all users from IdP
idpUsers, err := am.idpManager.GetAllUsers(ctx)
if err != nil {
return err
}
log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData))
log.WithContext(ctx).Infof("%d users received from IdP management", len(idpUsers))
// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
// update their AppMetadata with the AccountID.
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
for _, user := range unsetData {
accountID, err := am.Store.GetAccountByUser(ctx, user.ID)
if err == nil {
data := userData[accountID.Id]
if data == nil {
data = make([]*idp.UserData, 0, 1)
}
user.AppMetadata.WTAccountID = accountID.Id
userData[accountID.Id] = append(data, user)
}
}
// Create a map for quick lookup of IdP users by ID
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
for _, user := range idpUsers {
idpUserMap[user.ID] = user
}
// Group IdP users by their account ID from NetBird's database
// NetBird DB is the source of truth for account membership
accountUsers := make(map[string][]*idp.UserData)
for _, idpUser := range idpUsers {
account, err := am.Store.GetAccountByUser(ctx, idpUser.ID)
if err != nil {
// User exists in IdP but not in NetBird - skip
continue
}
accountUsers[account.Id] = append(accountUsers[account.Id], idpUser)
}
delete(userData, idp.UnsetAccountID)
rcvdUsers := 0
for accountID, users := range userData {
for accountID, users := range accountUsers {
rcvdUsers += len(users)
err = am.cacheManager.Set(am.ctx, accountID, users, cacheEntryExpiration())
if err != nil {
return err
}
}
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(accountUsers))
return nil
}
@@ -742,10 +740,6 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
return "", err
}
return account.Id, nil
}
return "", err
@@ -757,35 +751,6 @@ func isNil(i idp.Manager) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
if user != nil && user.AppMetadata.WTAccountID == accountID {
// it was already set, so we skip the unnecessary update
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
accountID, userID)
return nil
}
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err != nil {
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
_, err = am.refreshCache(ctx, accountID)
if err != nil {
return err
}
}
return nil
}
func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) (any, []cacheStore.Option, error) {
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
@@ -797,28 +762,32 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
// nolint:staticcheck
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
// Get users belonging to this account from NetBird's database
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err
}
userData, err := am.idpManager.GetAccount(ctx, accountIDString)
// Get all users from IdP (we don't have per-account queries anymore)
idpUsers, err := am.idpManager.GetAllUsers(ctx)
if err != nil {
return nil, nil, err
}
log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString)
log.WithContext(ctx).Debugf("%d total users in IdP, matching against account %s", len(idpUsers), accountIDString)
dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData {
dataMap[datum.ID] = datum
// Create a map for quick lookup of IdP users by ID
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
for _, user := range idpUsers {
idpUserMap[user.ID] = user
}
// Match account users against IdP users
matchedUserData := make([]*idp.UserData, 0)
for _, user := range accountUsers {
if user.IsServiceUser {
continue
}
datum, ok := dataMap[user.Id]
datum, ok := idpUserMap[user.Id]
if !ok {
log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id)
continue
@@ -972,7 +941,7 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
return data, nil
}
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count
func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
userDataMap := make(map[string]*idp.UserData, len(data))
for _, datum := range data {
@@ -980,15 +949,10 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
}
// the accountUsers ID list of non integration users from store, we check if cache has all of them
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
// as result of for loop knownUsersCount will have number of users are not presented in the cached
knownUsersCount := len(accountUsers)
for user, loggedInOnce := range accountUsers {
if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
return false
}
for user := range accountUsers {
if _, ok := userDataMap[user]; ok {
knownUsersCount--
continue
}
@@ -1078,12 +1042,6 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
return err
}
// we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
if err != nil {
return err
}
return nil
}
@@ -1110,11 +1068,6 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
if err != nil {
return "", err
}
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
return newAccount.Id, nil
@@ -1139,11 +1092,6 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
if err != nil {
return "", err
}
if newUser.PendingApproval {
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
} else {
@@ -1155,34 +1103,30 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
// redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil
}
user, err := am.lookupUserInCache(ctx, userID, accountID)
// Get user from NetBird's database (source of truth for authorization data)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return err
}
if user == nil {
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
return status.Errorf(status.NotFound, "user %s not found", userID)
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
// Check if user has a pending invite that needs to be redeemed
if user.PendingInvite {
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(ctx, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
return
}
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
}()
// Update user to mark invite as redeemed
user.PendingInvite = false
err = am.Store.SaveUser(ctx, user)
if err != nil {
log.WithContext(ctx).Warnf("failed to redeem invite for user %s: %v", userID, err)
return err
}
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", userID, accountID)
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
}
return nil

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
"github.com/netbirdio/netbird/management/server/http/handlers/connectors"
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
@@ -134,6 +135,7 @@ func NewAPIHandler(
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
connectors.AddEndpoints(accountManager, router)
return rootRouter, nil
}

View File

@@ -0,0 +1,590 @@
package connectors
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// API request/response types
// ConnectorResponse represents an IdP connector in API responses
type ConnectorResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
State string `json:"state"`
Issuer string `json:"issuer,omitempty"`
Servers []string `json:"servers,omitempty"`
}
// OIDCConnectorRequest represents a request to create an OIDC connector
type OIDCConnectorRequest struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
Scopes []string `json:"scopes,omitempty"`
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
}
// LDAPConnectorRequest represents a request to create an LDAP connector
type LDAPConnectorRequest struct {
Name string `json:"name"`
Servers []string `json:"servers"`
StartTLS bool `json:"start_tls,omitempty"`
BaseDN string `json:"base_dn"`
BindDN string `json:"bind_dn"`
BindPassword string `json:"bind_password"`
UserBase string `json:"user_base,omitempty"`
UserObjectClass []string `json:"user_object_class,omitempty"`
UserFilters []string `json:"user_filters,omitempty"`
Timeout string `json:"timeout,omitempty"`
Attributes *LDAPAttributesRequest `json:"attributes,omitempty"`
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
}
// LDAPAttributesRequest maps LDAP attributes to user fields
type LDAPAttributesRequest struct {
IDAttribute string `json:"id_attribute,omitempty"`
FirstNameAttribute string `json:"first_name_attribute,omitempty"`
LastNameAttribute string `json:"last_name_attribute,omitempty"`
DisplayNameAttribute string `json:"display_name_attribute,omitempty"`
EmailAttribute string `json:"email_attribute,omitempty"`
}
// SAMLConnectorRequest represents a request to create a SAML connector
type SAMLConnectorRequest struct {
Name string `json:"name"`
MetadataXML string `json:"metadata_xml,omitempty"`
MetadataURL string `json:"metadata_url,omitempty"`
Binding string `json:"binding,omitempty"`
WithSignedRequest bool `json:"with_signed_request,omitempty"`
NameIDFormat string `json:"name_id_format,omitempty"`
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
}
// handler handles HTTP requests for IdP connectors
type handler struct {
accountManager account.Manager
}
// AddEndpoints registers the connector endpoints to the router
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
h := &handler{accountManager: accountManager}
router.HandleFunc("/connectors", h.listConnectors).Methods("GET", "OPTIONS")
router.HandleFunc("/connectors/{connectorId}", h.getConnector).Methods("GET", "OPTIONS")
router.HandleFunc("/connectors/{connectorId}", h.deleteConnector).Methods("DELETE", "OPTIONS")
router.HandleFunc("/connectors/oidc", h.addOIDCConnector).Methods("POST", "OPTIONS")
router.HandleFunc("/connectors/ldap", h.addLDAPConnector).Methods("POST", "OPTIONS")
router.HandleFunc("/connectors/saml", h.addSAMLConnector).Methods("POST", "OPTIONS")
router.HandleFunc("/connectors/{connectorId}/activate", h.activateConnector).Methods("POST", "OPTIONS")
router.HandleFunc("/connectors/{connectorId}/deactivate", h.deleteConnector).Methods("POST", "OPTIONS")
}
// getConnectorManager retrieves the connector manager from the IdP manager
func (h *handler) getConnectorManager() (idp.ConnectorManager, error) {
idpManager := h.accountManager.GetIdpManager()
if idpManager == nil {
return nil, status.Errorf(status.PreconditionFailed, "IdP manager is not configured")
}
connectorManager, ok := idpManager.(idp.ConnectorManager)
if !ok {
return nil, status.Errorf(status.PreconditionFailed, "IdP manager does not support connector management")
}
return connectorManager, nil
}
// listConnectors returns all configured IdP connectors
func (h *handler) listConnectors(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// Only admins can manage connectors
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
connectors, err := connectorManager.ListConnectors(r.Context())
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to list connectors: %v", err), w)
return
}
response := make([]*ConnectorResponse, 0, len(connectors))
for _, c := range connectors {
response = append(response, toConnectorResponse(c))
}
util.WriteJSONObject(r.Context(), w, response)
}
// getConnector returns a specific connector by ID
func (h *handler) getConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
vars := mux.Vars(r)
connectorID := vars["connectorId"]
if connectorID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
connector, err := connectorManager.GetConnector(r.Context(), connectorID)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "connector not found: %v", err), w)
return
}
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
}
// deleteConnector removes a connector
func (h *handler) deleteConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
vars := mux.Vars(r)
connectorID := vars["connectorId"]
if connectorID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if err := connectorManager.DeleteConnector(r.Context(), connectorID); err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to delete connector: %v", err), w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// addOIDCConnector creates a new OIDC connector
func (h *handler) addOIDCConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
var req OIDCConnectorRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
return
}
if req.Issuer == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "issuer is required"), w)
return
}
if req.ClientID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "client_id is required"), w)
return
}
if req.ClientSecret == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "client_secret is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
config := idp.OIDCConnectorConfig{
Name: req.Name,
Issuer: req.Issuer,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
Scopes: req.Scopes,
IsAutoCreation: req.IsAutoCreation,
IsAutoUpdate: req.IsAutoUpdate,
IsCreationAllowed: req.IsCreationAllowed,
IsLinkingAllowed: req.IsLinkingAllowed,
}
connector, err := connectorManager.AddOIDCConnector(r.Context(), config)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add OIDC connector: %v", err), w)
return
}
w.WriteHeader(http.StatusCreated)
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
}
// addLDAPConnector creates a new LDAP connector
func (h *handler) addLDAPConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
var req LDAPConnectorRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
return
}
if len(req.Servers) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "at least one server is required"), w)
return
}
if req.BaseDN == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "base_dn is required"), w)
return
}
if req.BindDN == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "bind_dn is required"), w)
return
}
if req.BindPassword == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "bind_password is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
config := idp.LDAPConnectorConfig{
Name: req.Name,
Servers: req.Servers,
StartTLS: req.StartTLS,
BaseDN: req.BaseDN,
BindDN: req.BindDN,
BindPassword: req.BindPassword,
UserBase: req.UserBase,
UserObjectClass: req.UserObjectClass,
UserFilters: req.UserFilters,
Timeout: req.Timeout,
IsAutoCreation: req.IsAutoCreation,
IsAutoUpdate: req.IsAutoUpdate,
IsCreationAllowed: req.IsCreationAllowed,
IsLinkingAllowed: req.IsLinkingAllowed,
}
if req.Attributes != nil {
config.Attributes = idp.LDAPAttributes{
IDAttribute: req.Attributes.IDAttribute,
FirstNameAttribute: req.Attributes.FirstNameAttribute,
LastNameAttribute: req.Attributes.LastNameAttribute,
DisplayNameAttribute: req.Attributes.DisplayNameAttribute,
EmailAttribute: req.Attributes.EmailAttribute,
}
}
connector, err := connectorManager.AddLDAPConnector(r.Context(), config)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add LDAP connector: %v", err), w)
return
}
w.WriteHeader(http.StatusCreated)
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
}
// addSAMLConnector creates a new SAML connector
func (h *handler) addSAMLConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
var req SAMLConnectorRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
return
}
if req.MetadataXML == "" && req.MetadataURL == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "either metadata_xml or metadata_url is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
config := idp.SAMLConnectorConfig{
Name: req.Name,
MetadataXML: req.MetadataXML,
MetadataURL: req.MetadataURL,
Binding: req.Binding,
WithSignedRequest: req.WithSignedRequest,
NameIDFormat: req.NameIDFormat,
IsAutoCreation: req.IsAutoCreation,
IsAutoUpdate: req.IsAutoUpdate,
IsCreationAllowed: req.IsCreationAllowed,
IsLinkingAllowed: req.IsLinkingAllowed,
}
connector, err := connectorManager.AddSAMLConnector(r.Context(), config)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add SAML connector: %v", err), w)
return
}
w.WriteHeader(http.StatusCreated)
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
}
// activateConnector adds the connector to the login policy
func (h *handler) activateConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
vars := mux.Vars(r)
connectorID := vars["connectorId"]
if connectorID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if err := connectorManager.ActivateConnector(r.Context(), connectorID); err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to activate connector: %v", err), w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// deactivateConnector removes the connector from the login policy
func (h *handler) deactivateConnector(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !user.HasAdminPower() {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
return
}
vars := mux.Vars(r)
connectorID := vars["connectorId"]
if connectorID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
return
}
connectorManager, err := h.getConnectorManager()
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if err := connectorManager.DeactivateConnector(r.Context(), connectorID); err != nil {
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to deactivate connector: %v", err), w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// toConnectorResponse converts an idp.Connector to a ConnectorResponse
func toConnectorResponse(c *idp.Connector) *ConnectorResponse {
return &ConnectorResponse{
ID: c.ID,
Name: c.Name,
Type: string(c.Type),
State: c.State,
Issuer: c.Issuer,
Servers: c.Servers,
}
}

View File

@@ -1,959 +0,0 @@
package idp
import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
)
// Auth0Manager auth0 manager client instance
type Auth0Manager struct {
authIssuer string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// Auth0ClientConfig auth0 manager client configurations
type Auth0ClientConfig struct {
Audience string
AuthIssuer string
ClientID string
ClientSecret string
GrantType string
}
// auth0JWTRequest payload struct to request a JWT Token
type auth0JWTRequest struct {
Audience string `json:"audience"`
AuthIssuer string `json:"auth_issuer"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
GrantType string `json:"grant_type"`
}
// Auth0Credentials auth0 authentication information
type Auth0Credentials struct {
clientConfig Auth0ClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// createUserRequest is a user create request
type createUserRequest struct {
Email string `json:"email"`
Name string `json:"name"`
AppMeta AppMetadata `json:"app_metadata"`
Connection string `json:"connection"`
Password string `json:"password"`
VerifyEmail bool `json:"verify_email"`
}
// userExportJobRequest is a user export request struct
type userExportJobRequest struct {
Format string `json:"format"`
Fields []map[string]string `json:"fields"`
}
// userExportJobResponse is a user export response struct
type userExportJobResponse struct {
Type string `json:"type"`
Status string `json:"status"`
ConnectionID string `json:"connection_id"`
Format string `json:"format"`
Limit int `json:"limit"`
Connection string `json:"connection"`
CreatedAt time.Time `json:"created_at"`
ID string `json:"id"`
}
// userExportJobStatusResponse is a user export status response struct
type userExportJobStatusResponse struct {
Type string `json:"type"`
Status string `json:"status"`
ConnectionID string `json:"connection_id"`
Format string `json:"format"`
Limit int `json:"limit"`
Location string `json:"location"`
Connection string `json:"connection"`
CreatedAt time.Time `json:"created_at"`
ID string `json:"id"`
}
// userVerificationJobRequest is a user verification request struct
type userVerificationJobRequest struct {
UserID string `json:"user_id"`
}
// auth0Profile represents an Auth0 user profile response
type auth0Profile struct {
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
CreatedAt string `json:"created_at"`
LastLogin string `json:"last_login"`
}
// Connections represents a single Auth0 connection
// https://auth0.com/docs/api/management/v2/connections/get-connections
type Connection struct {
Id string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"display_name"`
IsDomainConnection bool `json:"is_domain_connection"`
Realms []string `json:"realms"`
Metadata map[string]string `json:"metadata"`
Options ConnectionOptions `json:"options"`
}
type ConnectionOptions struct {
DomainAliases []string `json:"domain_aliases"`
}
// NewAuth0Manager creates a new instance of the Auth0Manager
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.AuthIssuer == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, AuthIssuer is missing")
}
if config.ClientID == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
}
if config.Audience == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, GrantType is missing")
}
credentials := &Auth0Credentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &Auth0Manager{
authIssuer: config.AuthIssuer,
credentials: credentials,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from Auth0
func (c *Auth0Credentials) jwtStillValid() bool {
return !c.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(c.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
var res *http.Response
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
p, err := c.helper.Marshal(auth0JWTRequest(c.clientConfig))
if err != nil {
return res, err
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", reqURL, payload)
if err != nil {
return res, err
}
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
res, err = c.httpClient.Do(req)
if err != nil {
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountRequestError()
}
return res, err
}
if res.StatusCode != 200 {
return res, fmt.Errorf("unable to get token, statusCode %d", res.StatusCode)
}
return res, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = c.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = json.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the Auth0 Management API
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
c.mux.Lock()
defer c.mux.Unlock()
if c.appMetrics != nil {
c.appMetrics.IDPMetrics().CountAuthenticate()
}
// If jwtToken has an expires time and we have enough time to do a request return immediately
if c.jwtStillValid() {
return c.jwtToken, nil
}
res, err := c.requestJWTToken(ctx)
if err != nil {
return c.jwtToken, err
}
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
}
}()
jwtToken, err := c.parseRequestJWTResponse(res.Body)
if err != nil {
return c.jwtToken, err
}
c.jwtToken = jwtToken
return c.jwtToken, nil
}
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
u, err := url.Parse(authIssuer + "/api/v2/users")
if err != nil {
return "", nil, err
}
q := u.Query()
q.Set("page", strconv.Itoa(page))
q.Set("search_engine", "v3")
q.Set("per_page", strconv.Itoa(perPage))
q.Set("q", "app_metadata.wt_account_id:"+accountID)
u.RawQuery = q.Encode()
return u.String(), q, nil
}
func requestByUserIDURL(authIssuer, userID string) string {
return authIssuer + "/api/v2/users/" + userID
}
// GetAccount returns all the users for a given profile. Calls Auth0 API.
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
var list []*UserData
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
// auth0 limitation of 1000 users via this endpoint
resultsPerPage := 50
for page := 0; page < 20; page++ {
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
}
var batch []UserData
err = json.Unmarshal(body, &batch)
if err != nil {
return nil, err
}
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
err = res.Body.Close()
if err != nil {
return nil, err
}
for user := range batch {
list = append(list, &batch[user])
}
if len(batch) == 0 || len(batch) < resultsPerPage {
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
return list, nil
}
}
return list, nil
}
// GetUserDataByID requests user data from auth0 via ID
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
reqURL := requestByUserIDURL(am.authIssuer, userID)
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var userData UserData
err = json.Unmarshal(body, &userData)
if err != nil {
return nil, err
}
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
}
}()
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get UserData, statusCode %d", res.StatusCode)
}
return &userData, nil
}
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return err
}
reqURL := am.authIssuer + "/api/v2/users/" + userID
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
req, err := http.NewRequest("PATCH", reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
res, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
}
}()
if res.StatusCode != 200 {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", res.StatusCode)
}
return nil
}
func buildCreateUserRequestPayload(email, name, accountID, invitedByEmail string) (string, error) {
invite := true
req := &createUserRequest{
Email: email,
Name: name,
AppMeta: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
WTInvitedBy: invitedByEmail,
},
Connection: "Username-Password-Authentication",
Password: GeneratePassword(8, 1, 1, 1),
VerifyEmail: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func buildUserExportRequest() (string, error) {
req := &userExportJobRequest{}
fields := make([]map[string]string, 0)
for _, field := range []string{"created_at", "last_login", "user_id", "email", "name"} {
fields = append(fields, map[string]string{"name": field})
}
fields = append(fields, map[string]string{
"name": "app_metadata.wt_account_id",
"export_as": "wt_account_id",
})
fields = append(fields, map[string]string{
"name": "app_metadata.wt_pending_invite",
"export_as": "wt_pending_invite",
})
req.Format = "json"
req.Fields = fields
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func (am *Auth0Manager) createRequest(
ctx context.Context, method string, endpoint string, body io.Reader,
) (*http.Request, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
reqURL := am.authIssuer + endpoint
req, err := http.NewRequest(method, reqURL, body)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
return req, nil
}
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/json")
return req, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
payloadString, err := buildUserExportRequest()
if err != nil {
return nil, err
}
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
if err != nil {
return nil, err
}
jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
}
}()
if jobResp.StatusCode != 200 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
}
var exportJobResp userExportJobResponse
body, err := io.ReadAll(jobResp.Body)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &exportJobResp)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
if exportJobResp.ID == "" {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
}
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
if err != nil {
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
return nil, err
}
if done {
return am.downloadProfileExport(ctx, downloadLink)
}
return nil, fmt.Errorf("failed extracting user profiles from auth0")
}
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userResp := []*UserData{}
err = am.helper.Unmarshal(body, &userResp)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
return userResp, nil
}
// CreateUser creates a new user in Auth0 Idp and sends an invite
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
if err != nil {
return nil, err
}
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
var createResp UserData
body, err := io.ReadAll(resp.Body)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &createResp)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
if createResp.ID == "" {
return nil, fmt.Errorf("couldn't create user: response %v", resp)
}
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
return &createResp, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
userVerificationReq := userVerificationJobRequest{
UserID: userID,
}
payload, err := am.helper.Marshal(userVerificationReq)
if err != nil {
return err
}
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
}
return nil
}
// DeleteUser from Auth0
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("execute delete request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("close delete request body: %v", err)
}
}()
if resp.StatusCode != 204 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// GetAllConnections returns detailed list of all connections filtered by given params.
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
var connections []Connection
q := make(url.Values)
q.Set("strategy", strings.Join(strategy, ","))
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
if err != nil {
return connections, err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return connections, err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
}
}()
if resp.StatusCode != 200 {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return connections, fmt.Errorf("unable to get connections, statusCode %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
return connections, err
}
err = am.helper.Unmarshal(body, &connections)
if err != nil {
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
return connections, err
}
return connections, err
}
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
defer cancel()
retry := time.NewTicker(10 * time.Second)
for {
select {
case <-ctx.Done():
log.WithContext(ctx).Debugf("Export job status stopped...\n")
return false, "", ctx.Err()
case <-retry.C:
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return false, "", err
}
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
if err != nil {
return false, "", err
}
var status userExportJobStatusResponse
err = am.helper.Unmarshal(body, &status)
if err != nil {
return false, "", err
}
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
if status.Status != "completed" {
continue
}
return true, status.Location, nil
}
}
}
// downloadProfileExport downloads user profiles from auth0 batch job
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
body, err := doGetReq(ctx, am.httpClient, location, "")
if err != nil {
return nil, err
}
bodyReader := bytes.NewReader(body)
gzipReader, err := gzip.NewReader(bodyReader)
if err != nil {
return nil, err
}
decoder := json.NewDecoder(gzipReader)
res := make(map[string][]*UserData)
for decoder.More() {
profile := auth0Profile{}
err = decoder.Decode(&profile)
if err != nil {
return nil, err
}
if profile.AccountID != "" {
if _, ok := res[profile.AccountID]; !ok {
res[profile.AccountID] = []*UserData{}
}
res[profile.AccountID] = append(res[profile.AccountID],
&UserData{
ID: profile.UserID,
Name: profile.Name,
Email: profile.Email,
AppMetadata: AppMetadata{
WTAccountID: profile.AccountID,
WTPendingInvite: &profile.PendingInvite,
},
})
}
}
return res, nil
}
// Boilerplate implementation for Get Requests.
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
if accessToken != "" {
req.Header.Add("authorization", "Bearer "+accessToken)
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() {
err = res.Body.Close()
if err != nil {
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
}
return body, nil
}

View File

@@ -1,473 +0,0 @@
package idp
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
type mockHTTPClient struct {
code int
resBody string
reqBody string
err error
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
if req.Body != nil {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
}
}
return &http.Response{
StatusCode: c.code,
Body: io.NopCloser(strings.NewReader(c.resBody)),
}, c.err
}
type mockJsonParser struct {
jsonParser JsonParser
marshalErrorString string
unmarshalErrorString string
}
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
if m.marshalErrorString != "" {
return nil, errors.New(m.marshalErrorString)
}
return m.jsonParser.Marshal(v)
}
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
if m.unmarshalErrorString != "" {
return errors.New(m.unmarshalErrorString)
}
return m.jsonParser.Unmarshal(data, v)
}
type mockAuth0Credentials struct {
jwtToken JWTToken
err error
}
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
return mc.jwtToken, mc.err
}
func newTestJWT(t *testing.T, expInt int) string {
t.Helper()
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iat": now.Unix(),
"exp": now.Add(time.Duration(expInt) * time.Second).Unix(),
})
var hmacSampleSecret []byte
tokenString, err := token.SignedString(hmacSampleSecret)
if err != nil {
t.Fatal(err)
}
return tokenString
}
func TestAuth0_RequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
name string
inputCode int
inputResBody string
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
requestJWTTokenTesttCase1 := requestJWTTokenTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := Auth0ClientConfig{}
creds := Auth0Credentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
res, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{}
err = json.Unmarshal(body, &jwtToken)
assert.NoError(t, err, "unable to parse the json input")
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
type parseRequestJWTResponseTest struct {
name string
inputResBody string
helper ManagerHelper
expectedToken string
expectedExpiresIn int
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
assertErrFuncMessage string
}
exp := 100
token := newTestJWT(t, exp)
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
name: "Parse Good JWT Body",
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
expectedExpiresIn: exp,
assertErrFunc: assert.NoError,
assertErrFuncMessage: "no error was expected",
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
inputResBody: "",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
assertErrFunc: assert.Error,
assertErrFuncMessage: "json error was expected",
}
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
config := Auth0ClientConfig{}
creds := Auth0Credentials{
clientConfig: config,
helper: testCase.helper,
}
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
})
}
}
func TestAuth0_JwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := Auth0ClientConfig{}
creds := Auth0Credentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestAuth0_Authenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
// expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := Auth0ClientConfig{}
creds := Auth0Credentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
assertErrFuncMessage string
}
exp := 15
token := newTestJWT(t, exp)
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAuth0Credentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAuth0Credentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
config := Auth0ClientConfig{}
var creds ManagerCredentials
if testCase.managerCreds != nil {
creds = testCase.managerCreds
} else {
creds = &Auth0Credentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
}
manager := Auth0Manager{
httpClient: &jwtReqClient,
credentials: creds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
})
}
}
func TestNewAuth0Manager(t *testing.T) {
type test struct {
name string
inputConfig Auth0ClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := Auth0ClientConfig{
AuthIssuer: "https://abc-auth0.eu.auth0.com",
Audience: "https://abc-auth0.eu.auth0.com/api/v2/",
ClientID: "abcdefg",
ClientSecret: "supersecret",
GrantType: "client_credentials",
}
testCase1 := test{
name: "Good Scenario With Config",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.ClientID = ""
testCase2 := test{
name: "Missing Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "shouldn't return error when field empty",
}
testCase3Config := defaultTestConfig
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
for _, testCase := range []test{testCase1, testCase2} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}

View File

@@ -1,428 +0,0 @@
package idp
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// AuthentikManager authentik manager client instance.
type AuthentikManager struct {
apiClient *api.APIClient
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// AuthentikClientConfig authentik manager client configurations.
type AuthentikClientConfig struct {
Issuer string
ClientID string
Username string
Password string
TokenEndpoint string
GrantType string
}
// AuthentikCredentials authentik authentication information.
type AuthentikCredentials struct {
clientConfig AuthentikClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// NewAuthentikManager creates a new instance of the AuthentikManager.
func NewAuthentikManager(config AuthentikClientConfig,
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, clientID is missing")
}
if config.Username == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Username is missing")
}
if config.Password == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Password is missing")
}
if config.TokenEndpoint == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing")
}
if config.Issuer == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing")
}
// authentik client configuration
issuerURL, err := url.Parse(config.Issuer)
if err != nil {
return nil, err
}
authentikConfig := api.NewConfiguration()
authentikConfig.HTTPClient = httpClient
authentikConfig.Host = issuerURL.Host
authentikConfig.Scheme = issuerURL.Scheme
credentials := &AuthentikCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &AuthentikManager{
apiClient: api.NewAPIClient(authentikConfig),
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from authentik.
func (ac *AuthentikCredentials) jwtStillValid() bool {
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("username", ac.clientConfig.Username)
data.Set("password", ac.clientConfig.Password)
data.Set("grant_type", ac.clientConfig.GrantType)
data.Set("scope", "goauthentik.io/api")
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get authentik token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = ac.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = ac.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the authentik management API.
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if ac.jwtStillValid() {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken(ctx)
if err != nil {
return ac.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
if err != nil {
return ac.jwtToken, err
}
ac.jwtToken = jwtToken
return ac.jwtToken, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from authentik via ID.
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return nil, err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return nil, err
}
user, resp, err := am.apiClient.CoreApi.CoreUsersRetrieve(ctx, int32(userPk)).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}
userData := parseAuthentikUser(*user)
userData.AppMetadata = appMetadata
return userData, nil
}
// GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0)
page := int32(1)
for {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return nil, err
}
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
if err != nil {
return nil, err
}
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
for _, user := range userList.Results {
users = append(users, parseAuthentikUser(user))
}
page = int32(userList.GetPagination().Next)
if userList.GetPagination().Next == 0 {
break
}
}
return users, nil
}
// CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return nil, err
}
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Email(email).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
users = append(users, parseAuthentikUser(user))
}
return users, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Authentik
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
if err != nil {
return err
}
defer resp.Body.Close() // nolint
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
}
return nil
}
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
value := map[string]api.APIKey{
"authentik": {
Key: jwtToken.AccessToken,
Prefix: jwtToken.TokenType,
},
}
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
}
func parseAuthentikUser(user api.User) *UserData {
return &UserData{
Email: *user.Email,
Name: user.Name,
ID: strconv.FormatInt(int64(user.Pk), 10),
}
}

View File

@@ -1,320 +0,0 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewAuthentikManager(t *testing.T) {
type test struct {
name string
inputConfig AuthentikClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := AuthentikClientConfig{
ClientID: "client_id",
Username: "username",
Password: "password",
TokenEndpoint: "https://localhost:8080/application/o/token/",
Issuer: "https://localhost:8080/application/o/netbird/",
GrantType: "client_credentials",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.ClientID = ""
testCase2 := test{
name: "Missing ClientID Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase3Config := defaultTestConfig
testCase3Config.Username = ""
testCase3 := test{
name: "Missing Username Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase4Config := defaultTestConfig
testCase4Config.Password = ""
testCase4 := test{
name: "Missing Password Configuration",
inputConfig: testCase4Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase5Config := defaultTestConfig
testCase5Config.GrantType = ""
testCase5 := test{
name: "Missing GrantType Configuration",
inputConfig: testCase5Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase6Config := defaultTestConfig
testCase6Config.Issuer = ""
testCase6 := test{
name: "Missing Issuer Configuration",
inputConfig: testCase6Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}
func TestAuthentikRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
name string
inputCode int
inputRespBody string
helper ManagerHelper
expectedFuncExitErrDiff error
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
requestJWTTokenTesttCase1 := requestJWTTokenTest{
name: "Good JWT Response",
inputCode: 200,
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
}
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
inputRespBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
}
config := AuthentikClientConfig{}
creds := AuthentikCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
resp, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
} else {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{}
err = testCase.helper.Unmarshal(body, &jwtToken)
assert.NoError(t, err, "unable to parse the json input")
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
}
})
}
}
func TestAuthentikParseRequestJWTResponse(t *testing.T) {
type parseRequestJWTResponseTest struct {
name string
inputRespBody string
helper ManagerHelper
expectedToken string
expectedExpiresIn int
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
exp := 100
token := newTestJWT(t, exp)
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
name: "Parse Good JWT Body",
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
expectedExpiresIn: exp,
assertErrFunc: assert.NoError,
assertErrFuncMessage: "no error was expected",
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
inputRespBody: "",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
assertErrFunc: assert.Error,
assertErrFuncMessage: "json error was expected",
}
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
config := AuthentikClientConfig{}
creds := AuthentikCredentials{
clientConfig: config,
helper: testCase.helper,
}
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
})
}
}
func TestAuthentikJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := AuthentikClientConfig{}
creds := AuthentikCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestAuthentikAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := AuthentikClientConfig{}
creds := AuthentikCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}

View File

@@ -1,459 +0,0 @@
package idp
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const profileFields = "id,displayName,mail,userPrincipalName"
// AzureManager azure manager client instance.
type AzureManager struct {
ClientID string
ObjectID string
GraphAPIEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// AzureClientConfig azure manager client configurations.
type AzureClientConfig struct {
ClientID string
ClientSecret string
ObjectID string
GraphAPIEndpoint string
TokenEndpoint string
GrantType string
}
// AzureCredentials azure authentication information.
type AzureCredentials struct {
clientConfig AzureClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// azureProfile represents an azure user profile.
type azureProfile map[string]any
// NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
}
if config.TokenEndpoint == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
}
if config.GraphAPIEndpoint == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
}
if config.ObjectID == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
}
credentials := &AzureCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &AzureManager{
ObjectID: config.ObjectID,
ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
func (ac *AzureCredentials) jwtStillValid() bool {
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
data.Set("grant_type", ac.clientConfig.GrantType)
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
if err != nil {
return nil, err
}
// get base url and add "/.default" as scope
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
scopeURL := baseURL + "/.default"
data.Set("scope", scopeURL)
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = ac.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = ac.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
if ac.appMetrics != nil {
ac.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if ac.jwtStillValid() {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken(ctx)
if err != nil {
return ac.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
if err != nil {
return ac.jwtToken, err
}
ac.jwtToken = jwtToken
return ac.jwtToken, nil
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get(ctx, "users/"+userID, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
userData := profile.userData()
userData.AppMetadata = appMetadata
return userData, nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get(ctx, "users/"+email, q)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, profile.userData())
return users, nil
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Azure.
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.WithContext(ctx).Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0)
q := url.Values{}
q.Add("$select", profileFields)
q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; {
body, err := am.get(ctx, nextLink, q)
if err != nil {
return nil, err
}
var profiles struct {
Value []azureProfile
NextLink string `json:"@odata.nextLink"`
}
err = am.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
for _, profile := range profiles.Value {
users = append(users, profile.userData())
}
nextLink = profiles.NextLink
}
return users, nil
}
// get perform Get requests.
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
var reqURL string
if strings.HasPrefix(resource, "https") {
// Already an absolute URL for paging
reqURL = resource
} else {
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
}
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile.
func (ap azureProfile) userData() *UserData {
id, ok := ap["id"].(string)
if !ok {
id = ""
}
email, ok := ap["userPrincipalName"].(string)
if !ok {
email = ""
}
name, ok := ap["displayName"].(string)
if !ok {
name = ""
}
return &UserData{
Email: email,
Name: name,
ID: id,
}
}

View File

@@ -1,178 +0,0 @@
package idp
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAzureJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := AzureClientConfig{}
creds := AzureCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestAzureAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := AzureClientConfig{}
creds := AzureCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}
func TestAzureProfile(t *testing.T) {
type azureProfileTest struct {
name string
invite bool
inputProfile azureProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "Good Request",
invite: false,
inputProfile: azureProfile{
"id": "test1",
"displayName": "John Doe",
"userPrincipalName": "test1@test.com",
},
expectedUserData: UserData{
Email: "test1@test.com",
Name: "John Doe",
ID: "test1",
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Missing User ID",
invite: true,
inputProfile: azureProfile{
"displayName": "John Doe",
"userPrincipalName": "test2@test.com",
},
expectedUserData: UserData{
Email: "test2@test.com",
Name: "John Doe",
},
}
azureProfileTestCase3 := azureProfileTest{
name: "Missing User Name",
invite: false,
inputProfile: azureProfile{
"id": "test3",
"userPrincipalName": "test3@test.com",
},
expectedUserData: UserData{
ID: "test3",
Email: "test3@test.com",
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
})
}
}

View File

@@ -0,0 +1,457 @@
package idp
import (
"context"
"fmt"
)
// ConnectorType represents the type of external identity provider connector
type ConnectorType string
const (
ConnectorTypeOIDC ConnectorType = "oidc"
ConnectorTypeLDAP ConnectorType = "ldap"
ConnectorTypeSAML ConnectorType = "saml"
)
// Connector represents an external identity provider configured in Zitadel
type Connector struct {
ID string `json:"id"`
Name string `json:"name"`
Type ConnectorType `json:"type"`
State string `json:"state"`
Issuer string `json:"issuer,omitempty"` // for OIDC
Servers []string `json:"servers,omitempty"` // for LDAP
}
// OIDCConnectorConfig contains configuration for adding an OIDC connector
type OIDCConnectorConfig struct {
Name string `json:"name"`
Issuer string `json:"issuer"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
Scopes []string `json:"scopes,omitempty"`
IsIDTokenMapping bool `json:"isIdTokenMapping,omitempty"`
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
IsAutoAccountLinking bool `json:"isAutoAccountLinking,omitempty"`
AccountLinkingEnabled bool `json:"accountLinkingEnabled,omitempty"`
}
// LDAPConnectorConfig contains configuration for adding an LDAP connector
type LDAPConnectorConfig struct {
Name string `json:"name"`
Servers []string `json:"servers"` // e.g., ["ldap://localhost:389"]
StartTLS bool `json:"startTls,omitempty"`
BaseDN string `json:"baseDn"`
BindDN string `json:"bindDn"`
BindPassword string `json:"bindPassword"`
UserBase string `json:"userBase,omitempty"` // typically "dn"
UserObjectClass []string `json:"userObjectClass,omitempty"` // e.g., ["user", "person"]
UserFilters []string `json:"userFilters,omitempty"` // e.g., ["uid", "email"]
Timeout string `json:"timeout,omitempty"` // e.g., "10s"
Attributes LDAPAttributes `json:"attributes,omitempty"`
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
}
// LDAPAttributes maps LDAP attributes to Zitadel user fields
type LDAPAttributes struct {
IDAttribute string `json:"idAttribute,omitempty"`
FirstNameAttribute string `json:"firstNameAttribute,omitempty"`
LastNameAttribute string `json:"lastNameAttribute,omitempty"`
DisplayNameAttribute string `json:"displayNameAttribute,omitempty"`
NickNameAttribute string `json:"nickNameAttribute,omitempty"`
EmailAttribute string `json:"emailAttribute,omitempty"`
EmailVerified string `json:"emailVerified,omitempty"`
PhoneAttribute string `json:"phoneAttribute,omitempty"`
PhoneVerified string `json:"phoneVerified,omitempty"`
AvatarURLAttribute string `json:"avatarUrlAttribute,omitempty"`
ProfileAttribute string `json:"profileAttribute,omitempty"`
}
// SAMLConnectorConfig contains configuration for adding a SAML connector
type SAMLConnectorConfig struct {
Name string `json:"name"`
MetadataXML string `json:"metadataXml,omitempty"`
MetadataURL string `json:"metadataUrl,omitempty"`
Binding string `json:"binding,omitempty"` // "SAML_BINDING_POST" or "SAML_BINDING_REDIRECT"
WithSignedRequest bool `json:"withSignedRequest,omitempty"`
NameIDFormat string `json:"nameIdFormat,omitempty"`
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
}
// ConnectorManager defines the interface for managing external IdP connectors
type ConnectorManager interface {
// AddOIDCConnector adds a Generic OIDC identity provider connector
AddOIDCConnector(ctx context.Context, config OIDCConnectorConfig) (*Connector, error)
// AddLDAPConnector adds an LDAP identity provider connector
AddLDAPConnector(ctx context.Context, config LDAPConnectorConfig) (*Connector, error)
// AddSAMLConnector adds a SAML identity provider connector
AddSAMLConnector(ctx context.Context, config SAMLConnectorConfig) (*Connector, error)
// ListConnectors returns all configured identity provider connectors
ListConnectors(ctx context.Context) ([]*Connector, error)
// GetConnector returns a specific connector by ID
GetConnector(ctx context.Context, connectorID string) (*Connector, error)
// DeleteConnector removes an identity provider connector
DeleteConnector(ctx context.Context, connectorID string) error
// ActivateConnector adds the connector to the login policy
ActivateConnector(ctx context.Context, connectorID string) error
// DeactivateConnector removes the connector from the login policy
DeactivateConnector(ctx context.Context, connectorID string) error
}
// zitadelProviderResponse represents the response from creating a provider
type zitadelProviderResponse struct {
ID string `json:"id"`
Details struct {
Sequence string `json:"sequence"`
CreationDate string `json:"creationDate"`
ChangeDate string `json:"changeDate"`
ResourceOwner string `json:"resourceOwner"`
} `json:"details"`
}
// zitadelProviderTemplate represents a provider in the list response
type zitadelProviderTemplate struct {
ID string `json:"id"`
Name string `json:"name"`
State string `json:"state"` // IDP_STATE_ACTIVE, IDP_STATE_INACTIVE
Type string `json:"type"` // IDP_TYPE_OIDC, IDP_TYPE_LDAP, IDP_TYPE_SAML, etc.
Owner string `json:"owner"` // IDP_OWNER_TYPE_ORG, IDP_OWNER_TYPE_SYSTEM
// Type-specific fields
OIDC *struct {
Issuer string `json:"issuer"`
ClientID string `json:"clientId"`
} `json:"oidc,omitempty"`
LDAP *struct {
Servers []string `json:"servers"`
BaseDN string `json:"baseDn"`
} `json:"ldap,omitempty"`
SAML *struct {
MetadataURL string `json:"metadataUrl"`
} `json:"saml,omitempty"`
}
// AddOIDCConnector adds a Generic OIDC identity provider connector to Zitadel
func (zm *ZitadelManager) AddOIDCConnector(ctx context.Context, config OIDCConnectorConfig) (*Connector, error) {
// Set defaults for creation/linking if not specified
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
config.IsCreationAllowed = true
config.IsLinkingAllowed = true
}
payload := map[string]any{
"name": config.Name,
"issuer": config.Issuer,
"clientId": config.ClientID,
"clientSecret": config.ClientSecret,
"isIdTokenMapping": config.IsIDTokenMapping,
"isAutoCreation": config.IsAutoCreation,
"isAutoUpdate": config.IsAutoUpdate,
"isCreationAllowed": config.IsCreationAllowed,
"isLinkingAllowed": config.IsLinkingAllowed,
}
if len(config.Scopes) > 0 {
payload["scopes"] = config.Scopes
}
body, err := zm.helper.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal OIDC connector config: %w", err)
}
respBody, err := zm.post(ctx, "idps/generic_oidc", string(body))
if err != nil {
return nil, fmt.Errorf("add OIDC connector: %w", err)
}
var resp zitadelProviderResponse
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
return nil, fmt.Errorf("unmarshal OIDC connector response: %w", err)
}
return &Connector{
ID: resp.ID,
Name: config.Name,
Type: ConnectorTypeOIDC,
State: "active",
Issuer: config.Issuer,
}, nil
}
// AddLDAPConnector adds an LDAP identity provider connector to Zitadel
func (zm *ZitadelManager) AddLDAPConnector(ctx context.Context, config LDAPConnectorConfig) (*Connector, error) {
// Set defaults
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
config.IsCreationAllowed = true
config.IsLinkingAllowed = true
}
if config.UserBase == "" {
config.UserBase = "dn"
}
if config.Timeout == "" {
config.Timeout = "10s"
}
payload := map[string]any{
"name": config.Name,
"servers": config.Servers,
"startTls": config.StartTLS,
"baseDn": config.BaseDN,
"bindDn": config.BindDN,
"bindPassword": config.BindPassword,
"userBase": config.UserBase,
"timeout": config.Timeout,
"isAutoCreation": config.IsAutoCreation,
"isAutoUpdate": config.IsAutoUpdate,
"isCreationAllowed": config.IsCreationAllowed,
"isLinkingAllowed": config.IsLinkingAllowed,
}
if len(config.UserObjectClass) > 0 {
payload["userObjectClasses"] = config.UserObjectClass
}
if len(config.UserFilters) > 0 {
payload["userFilters"] = config.UserFilters
}
// Add attribute mappings if provided
attrs := make(map[string]string)
if config.Attributes.IDAttribute != "" {
attrs["idAttribute"] = config.Attributes.IDAttribute
}
if config.Attributes.FirstNameAttribute != "" {
attrs["firstNameAttribute"] = config.Attributes.FirstNameAttribute
}
if config.Attributes.LastNameAttribute != "" {
attrs["lastNameAttribute"] = config.Attributes.LastNameAttribute
}
if config.Attributes.DisplayNameAttribute != "" {
attrs["displayNameAttribute"] = config.Attributes.DisplayNameAttribute
}
if config.Attributes.EmailAttribute != "" {
attrs["emailAttribute"] = config.Attributes.EmailAttribute
}
if len(attrs) > 0 {
payload["attributes"] = attrs
}
body, err := zm.helper.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal LDAP connector config: %w", err)
}
respBody, err := zm.post(ctx, "idps/ldap", string(body))
if err != nil {
return nil, fmt.Errorf("add LDAP connector: %w", err)
}
var resp zitadelProviderResponse
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
return nil, fmt.Errorf("unmarshal LDAP connector response: %w", err)
}
return &Connector{
ID: resp.ID,
Name: config.Name,
Type: ConnectorTypeLDAP,
State: "active",
Servers: config.Servers,
}, nil
}
// AddSAMLConnector adds a SAML identity provider connector to Zitadel
func (zm *ZitadelManager) AddSAMLConnector(ctx context.Context, config SAMLConnectorConfig) (*Connector, error) {
// Set defaults
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
config.IsCreationAllowed = true
config.IsLinkingAllowed = true
}
payload := map[string]any{
"name": config.Name,
"isAutoCreation": config.IsAutoCreation,
"isAutoUpdate": config.IsAutoUpdate,
"isCreationAllowed": config.IsCreationAllowed,
"isLinkingAllowed": config.IsLinkingAllowed,
}
if config.MetadataXML != "" {
payload["metadataXml"] = config.MetadataXML
} else if config.MetadataURL != "" {
payload["metadataUrl"] = config.MetadataURL
} else {
return nil, fmt.Errorf("either metadataXml or metadataUrl must be provided")
}
if config.Binding != "" {
payload["binding"] = config.Binding
}
if config.WithSignedRequest {
payload["withSignedRequest"] = config.WithSignedRequest
}
if config.NameIDFormat != "" {
payload["nameIdFormat"] = config.NameIDFormat
}
body, err := zm.helper.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal SAML connector config: %w", err)
}
respBody, err := zm.post(ctx, "idps/saml", string(body))
if err != nil {
return nil, fmt.Errorf("add SAML connector: %w", err)
}
var resp zitadelProviderResponse
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
return nil, fmt.Errorf("unmarshal SAML connector response: %w", err)
}
return &Connector{
ID: resp.ID,
Name: config.Name,
Type: ConnectorTypeSAML,
State: "active",
}, nil
}
// ListConnectors returns all configured identity provider connectors
func (zm *ZitadelManager) ListConnectors(ctx context.Context) ([]*Connector, error) {
// Use the search endpoint to list all providers
respBody, err := zm.post(ctx, "idps/_search", "{}")
if err != nil {
return nil, fmt.Errorf("list connectors: %w", err)
}
var resp struct {
Result []zitadelProviderTemplate `json:"result"`
}
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
return nil, fmt.Errorf("unmarshal connectors response: %w", err)
}
connectors := make([]*Connector, 0, len(resp.Result))
for _, p := range resp.Result {
connector := &Connector{
ID: p.ID,
Name: p.Name,
State: normalizeState(p.State),
Type: normalizeType(p.Type),
}
// Add type-specific fields
if p.OIDC != nil {
connector.Issuer = p.OIDC.Issuer
}
if p.LDAP != nil {
connector.Servers = p.LDAP.Servers
}
connectors = append(connectors, connector)
}
return connectors, nil
}
// GetConnector returns a specific connector by ID
func (zm *ZitadelManager) GetConnector(ctx context.Context, connectorID string) (*Connector, error) {
respBody, err := zm.get(ctx, fmt.Sprintf("idps/%s", connectorID), nil)
if err != nil {
return nil, fmt.Errorf("get connector: %w", err)
}
var resp struct {
IDP zitadelProviderTemplate `json:"idp"`
}
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
return nil, fmt.Errorf("unmarshal connector response: %w", err)
}
connector := &Connector{
ID: resp.IDP.ID,
Name: resp.IDP.Name,
State: normalizeState(resp.IDP.State),
Type: normalizeType(resp.IDP.Type),
}
if resp.IDP.OIDC != nil {
connector.Issuer = resp.IDP.OIDC.Issuer
}
if resp.IDP.LDAP != nil {
connector.Servers = resp.IDP.LDAP.Servers
}
return connector, nil
}
// DeleteConnector removes an identity provider connector
func (zm *ZitadelManager) DeleteConnector(ctx context.Context, connectorID string) error {
if err := zm.delete(ctx, fmt.Sprintf("idps/%s", connectorID)); err != nil {
return fmt.Errorf("delete connector: %w", err)
}
return nil
}
// ActivateConnector adds the connector to the organization's login policy
func (zm *ZitadelManager) ActivateConnector(ctx context.Context, connectorID string) error {
payload := map[string]string{
"idpId": connectorID,
}
body, err := zm.helper.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal activate request: %w", err)
}
_, err = zm.post(ctx, "policies/login/idps", string(body))
if err != nil {
return fmt.Errorf("activate connector: %w", err)
}
return nil
}
// DeactivateConnector removes the connector from the organization's login policy
func (zm *ZitadelManager) DeactivateConnector(ctx context.Context, connectorID string) error {
if err := zm.delete(ctx, fmt.Sprintf("policies/login/idps/%s", connectorID)); err != nil {
return fmt.Errorf("deactivate connector: %w", err)
}
return nil
}
// normalizeState converts Zitadel state to a simple string
func normalizeState(state string) string {
switch state {
case "IDP_STATE_ACTIVE":
return "active"
case "IDP_STATE_INACTIVE":
return "inactive"
default:
return state
}
}
// normalizeType converts Zitadel type to ConnectorType
func normalizeType(idpType string) ConnectorType {
switch idpType {
case "IDP_TYPE_OIDC", "IDP_TYPE_OIDC_GENERIC":
return ConnectorTypeOIDC
case "IDP_TYPE_LDAP":
return ConnectorTypeLDAP
case "IDP_TYPE_SAML":
return ConnectorTypeSAML
default:
return ConnectorType(idpType)
}
}

View File

@@ -0,0 +1,313 @@
package idp
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestZitadelManager_AddOIDCConnector(t *testing.T) {
// Create a mock response for the OIDC connector creation
mockResponse := `{"id": "oidc-123", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
mockClient := &mockHTTPClient{
code: http.StatusCreated,
resBody: mockResponse,
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
config := OIDCConnectorConfig{
Name: "Okta",
Issuer: "https://okta.example.com",
ClientID: "client-123",
ClientSecret: "secret-456",
Scopes: []string{"openid", "profile", "email"},
}
connector, err := manager.AddOIDCConnector(context.Background(), config)
require.NoError(t, err)
assert.Equal(t, "oidc-123", connector.ID)
assert.Equal(t, "Okta", connector.Name)
assert.Equal(t, ConnectorTypeOIDC, connector.Type)
assert.Equal(t, "https://okta.example.com", connector.Issuer)
// Verify the request body contains expected fields
var reqBody map[string]any
err = json.Unmarshal([]byte(mockClient.reqBody), &reqBody)
require.NoError(t, err)
assert.Equal(t, "Okta", reqBody["name"])
assert.Equal(t, "https://okta.example.com", reqBody["issuer"])
assert.Equal(t, "client-123", reqBody["clientId"])
}
func TestZitadelManager_AddLDAPConnector(t *testing.T) {
mockResponse := `{"id": "ldap-456", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
mockClient := &mockHTTPClient{
code: http.StatusCreated,
resBody: mockResponse,
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
config := LDAPConnectorConfig{
Name: "Corporate LDAP",
Servers: []string{"ldap://ldap.example.com:389"},
BaseDN: "dc=example,dc=com",
BindDN: "cn=admin,dc=example,dc=com",
BindPassword: "admin-password",
Attributes: LDAPAttributes{
IDAttribute: "uid",
EmailAttribute: "mail",
},
}
connector, err := manager.AddLDAPConnector(context.Background(), config)
require.NoError(t, err)
assert.Equal(t, "ldap-456", connector.ID)
assert.Equal(t, "Corporate LDAP", connector.Name)
assert.Equal(t, ConnectorTypeLDAP, connector.Type)
assert.Equal(t, []string{"ldap://ldap.example.com:389"}, connector.Servers)
}
func TestZitadelManager_AddSAMLConnector(t *testing.T) {
mockResponse := `{"id": "saml-789", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
mockClient := &mockHTTPClient{
code: http.StatusCreated,
resBody: mockResponse,
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
config := SAMLConnectorConfig{
Name: "Enterprise SAML",
MetadataURL: "https://idp.example.com/metadata.xml",
}
connector, err := manager.AddSAMLConnector(context.Background(), config)
require.NoError(t, err)
assert.Equal(t, "saml-789", connector.ID)
assert.Equal(t, "Enterprise SAML", connector.Name)
assert.Equal(t, ConnectorTypeSAML, connector.Type)
}
func TestZitadelManager_AddSAMLConnector_RequiresMetadata(t *testing.T) {
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: &mockHTTPClient{},
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
config := SAMLConnectorConfig{
Name: "Invalid SAML",
// Neither MetadataXML nor MetadataURL provided
}
_, err := manager.AddSAMLConnector(context.Background(), config)
require.Error(t, err)
assert.Contains(t, err.Error(), "metadataXml or metadataUrl must be provided")
}
func TestZitadelManager_ListConnectors(t *testing.T) {
mockResponse := `{
"result": [
{
"id": "oidc-1",
"name": "Google",
"state": "IDP_STATE_ACTIVE",
"type": "IDP_TYPE_OIDC",
"oidc": {"issuer": "https://accounts.google.com", "clientId": "google-client"}
},
{
"id": "ldap-1",
"name": "AD",
"state": "IDP_STATE_INACTIVE",
"type": "IDP_TYPE_LDAP",
"ldap": {"servers": ["ldap://ad.example.com:389"], "baseDn": "dc=example,dc=com"}
}
]
}`
mockClient := &mockHTTPClient{
code: http.StatusOK,
resBody: mockResponse,
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
connectors, err := manager.ListConnectors(context.Background())
require.NoError(t, err)
require.Len(t, connectors, 2)
assert.Equal(t, "oidc-1", connectors[0].ID)
assert.Equal(t, "Google", connectors[0].Name)
assert.Equal(t, "active", connectors[0].State)
assert.Equal(t, ConnectorTypeOIDC, connectors[0].Type)
assert.Equal(t, "https://accounts.google.com", connectors[0].Issuer)
assert.Equal(t, "ldap-1", connectors[1].ID)
assert.Equal(t, "AD", connectors[1].Name)
assert.Equal(t, "inactive", connectors[1].State)
assert.Equal(t, ConnectorTypeLDAP, connectors[1].Type)
assert.Equal(t, []string{"ldap://ad.example.com:389"}, connectors[1].Servers)
}
func TestZitadelManager_GetConnector(t *testing.T) {
mockResponse := `{
"idp": {
"id": "oidc-123",
"name": "Okta",
"state": "IDP_STATE_ACTIVE",
"type": "IDP_TYPE_OIDC",
"oidc": {"issuer": "https://okta.example.com", "clientId": "client-123"}
}
}`
mockClient := &mockHTTPClient{
code: http.StatusOK,
resBody: mockResponse,
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
connector, err := manager.GetConnector(context.Background(), "oidc-123")
require.NoError(t, err)
assert.Equal(t, "oidc-123", connector.ID)
assert.Equal(t, "Okta", connector.Name)
assert.Equal(t, ConnectorTypeOIDC, connector.Type)
assert.Equal(t, "https://okta.example.com", connector.Issuer)
}
func TestZitadelManager_DeleteConnector(t *testing.T) {
mockClient := &mockHTTPClient{
code: http.StatusOK,
resBody: "{}",
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
err := manager.DeleteConnector(context.Background(), "oidc-123")
require.NoError(t, err)
}
func TestZitadelManager_ActivateConnector(t *testing.T) {
mockClient := &mockHTTPClient{
code: http.StatusOK,
resBody: "{}",
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
err := manager.ActivateConnector(context.Background(), "oidc-123")
require.NoError(t, err)
// Verify the request body
var reqBody map[string]string
err = json.Unmarshal([]byte(mockClient.reqBody), &reqBody)
require.NoError(t, err)
assert.Equal(t, "oidc-123", reqBody["idpId"])
}
func TestZitadelManager_DeactivateConnector(t *testing.T) {
mockClient := &mockHTTPClient{
code: http.StatusOK,
resBody: "{}",
}
manager := &ZitadelManager{
managementEndpoint: "https://zitadel.example.com/management/v1",
httpClient: mockClient,
credentials: &mockCredentials{token: "test-token"},
helper: JsonParser{},
}
err := manager.DeactivateConnector(context.Background(), "oidc-123")
require.NoError(t, err)
}
func TestNormalizeState(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"IDP_STATE_ACTIVE", "active"},
{"IDP_STATE_INACTIVE", "inactive"},
{"custom", "custom"},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
assert.Equal(t, tc.expected, normalizeState(tc.input))
})
}
}
func TestNormalizeType(t *testing.T) {
tests := []struct {
input string
expected ConnectorType
}{
{"IDP_TYPE_OIDC", ConnectorTypeOIDC},
{"IDP_TYPE_OIDC_GENERIC", ConnectorTypeOIDC},
{"IDP_TYPE_LDAP", ConnectorTypeLDAP},
{"IDP_TYPE_SAML", ConnectorTypeSAML},
{"CUSTOM", ConnectorType("CUSTOM")},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
assert.Equal(t, tc.expected, normalizeType(tc.input))
})
}
}
// mockCredentials is a mock implementation of ManagerCredentials for testing
type mockCredentials struct {
token string
}
func (m *mockCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
return JWTToken{AccessToken: m.token}, nil
}

View File

@@ -1,263 +0,0 @@
package idp
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/option"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// GoogleWorkspaceManager Google Workspace manager client instance.
type GoogleWorkspaceManager struct {
usersService *admin.UsersService
CustomerID string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
type GoogleWorkspaceClientConfig struct {
ServiceAccountKey string
CustomerID string
}
// GoogleWorkspaceCredentials Google Workspace authentication information.
type GoogleWorkspaceCredentials struct {
clientConfig GoogleWorkspaceClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.CustomerID == "" {
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
}
credentials := &GoogleWorkspaceCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
// Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
if err != nil {
return nil, err
}
service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
option.WithCredentials(adminCredentials),
)
if err != nil {
return nil, err
}
return &GoogleWorkspaceManager{
usersService: service.Users,
CustomerID: config.CustomerID,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata = appMetadata
return userData, nil
}
// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
users := make([]*UserData, 0)
pageToken := ""
for {
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
if pageToken != "" {
call.PageToken(pageToken)
}
resp, err := call.Do()
if err != nil {
return nil, err
}
for _, user := range resp.Users {
users = append(users, parseGoogleWorkspaceUser(user))
}
pageToken = resp.NextPageToken
if pageToken == "" {
break
}
}
return users, nil
}
// CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
users := make([]*UserData, 0)
users = append(users, parseGoogleWorkspaceUser(user))
return users, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from GoogleWorkspace.
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
if err := gm.usersService.Delete(userID).Do(); err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path.
// It returns the retrieved credentials or an error if unsuccessful.
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err)
}
creds, err := google.CredentialsFromJSON(
context.Background(),
decodeKey,
admin.AdminDirectoryUserReadonlyScope,
)
if err == nil {
// No need to fallback to the default Google credentials path
return creds, nil
}
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.WithContext(ctx).Debug("falling back to default google credentials location")
creds, err = google.FindDefaultCredentials(
context.Background(),
admin.AdminDirectoryUserReadonlyScope,
)
if err != nil {
return nil, err
}
return creds, nil
}
// parseGoogleWorkspaceUser parse google user to UserData.
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
return &UserData{
ID: user.Id,
Email: user.PrimaryEmail,
Name: user.Name.FullName,
}
}

View File

@@ -11,24 +11,25 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
// UnsetAccountID is a special key to map users without an account ID
UnsetAccountID = "unset"
)
// Manager idp manager interface
// Note: NetBird is the single source of truth for authorization data (roles, account membership, invite status).
// The IdP only stores identity information (email, name, credentials).
type Manager interface {
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
// CreateUser creates a new user in the IdP. Returns basic user data (ID, email, name).
CreateUser(ctx context.Context, email, name string) (*UserData, error)
// GetUserDataByID retrieves user identity data from the IdP by user ID.
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
// GetUserByEmail searches for users by email address.
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
// GetAllUsers returns all users from the IdP for cache warming.
GetAllUsers(ctx context.Context) ([]*UserData, error)
// InviteUserByID resends an invitation to a user who hasn't completed signup.
InviteUserByID(ctx context.Context, userID string) error
// DeleteUser removes a user from the IdP.
DeleteUser(ctx context.Context, userID string) error
}
// ClientConfig defines common client configuration for all IdP manager
// ClientConfig defines common client configuration for the IdP manager
type ClientConfig struct {
Issuer string
TokenEndpoint string
@@ -42,13 +43,10 @@ type ExtraConfig map[string]string
// Config an idp configuration struct to be loaded from management server's config file
type Config struct {
ManagerType string
ClientConfig *ClientConfig
ExtraConfig ExtraConfig
Auth0ClientCredentials *Auth0ClientConfig
AzureClientCredentials *AzureClientConfig
KeycloakClientCredentials *KeycloakClientConfig
ZitadelClientCredentials *ZitadelClientConfig
ManagerType string
ClientConfig *ClientConfig
ExtraConfig ExtraConfig
ZitadelClientCredentials *ZitadelClientConfig
}
// ManagerCredentials interface that authenticates using the credential of each type of idp
@@ -67,11 +65,12 @@ type ManagerHelper interface {
Unmarshal(data []byte, v interface{}) error
}
// UserData represents identity information from the IdP.
// Note: Authorization data (account membership, roles, invite status) is stored in NetBird's DB.
type UserData struct {
Email string `json:"email"`
Name string `json:"name"`
ID string `json:"user_id"`
AppMetadata AppMetadata `json:"app_metadata"`
Email string `json:"email"`
Name string `json:"name"`
ID string `json:"user_id"`
}
func (u *UserData) MarshalBinary() (data []byte, err error) {
@@ -91,15 +90,6 @@ func (u *UserData) Unmarshal(data []byte) (err error) {
return json.Unmarshal(data, &u)
}
// AppMetadata user app metadata to associate with a profile
type AppMetadata struct {
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
// maps to wt_account_id when json.marshal
WTAccountID string `json:"wt_account_id,omitempty"`
WTPendingInvite *bool `json:"wt_pending_invite,omitempty"`
WTInvitedBy string `json:"wt_invited_by_email,omitempty"`
}
// JWTToken a JWT object that holds information of a token
type JWTToken struct {
AccessToken string `json:"access_token"`
@@ -109,7 +99,8 @@ type JWTToken struct {
TokenType string `json:"token_type"`
}
// NewManager returns a new idp manager based on the configuration that it receives
// NewManager returns a new idp manager based on the configuration that it receives.
// Only Zitadel is supported as the IdP manager.
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
if config.ClientConfig != nil {
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
@@ -118,46 +109,6 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
switch strings.ToLower(config.ManagerType) {
case "none", "":
return nil, nil //nolint:nilnil
case "auth0":
auth0ClientConfig := config.Auth0ClientCredentials
if config.ClientConfig != nil {
auth0ClientConfig = &Auth0ClientConfig{
Audience: config.ExtraConfig["Audience"],
AuthIssuer: config.ClientConfig.Issuer,
ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType,
}
}
return NewAuth0Manager(*auth0ClientConfig, appMetrics)
case "azure":
azureClientConfig := config.AzureClientCredentials
if config.ClientConfig != nil {
azureClientConfig = &AzureClientConfig{
ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType,
TokenEndpoint: config.ClientConfig.TokenEndpoint,
ObjectID: config.ExtraConfig["ObjectId"],
GraphAPIEndpoint: config.ExtraConfig["GraphApiEndpoint"],
}
}
return NewAzureManager(*azureClientConfig, appMetrics)
case "keycloak":
keycloakClientConfig := config.KeycloakClientCredentials
if config.ClientConfig != nil {
keycloakClientConfig = &KeycloakClientConfig{
ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType,
TokenEndpoint: config.ClientConfig.TokenEndpoint,
AdminEndpoint: config.ExtraConfig["AdminEndpoint"],
}
}
return NewKeycloakManager(*keycloakClientConfig, appMetrics)
case "zitadel":
zitadelClientConfig := config.ZitadelClientCredentials
if config.ClientConfig != nil {
@@ -172,42 +123,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
}
return NewZitadelManager(*zitadelClientConfig, appMetrics)
case "authentik":
authentikConfig := AuthentikClientConfig{
Issuer: config.ClientConfig.Issuer,
ClientID: config.ClientConfig.ClientID,
TokenEndpoint: config.ClientConfig.TokenEndpoint,
GrantType: config.ClientConfig.GrantType,
Username: config.ExtraConfig["Username"],
Password: config.ExtraConfig["Password"],
}
return NewAuthentikManager(authentikConfig, appMetrics)
case "okta":
oktaClientConfig := OktaClientConfig{
Issuer: config.ClientConfig.Issuer,
TokenEndpoint: config.ClientConfig.TokenEndpoint,
GrantType: config.ClientConfig.GrantType,
APIToken: config.ExtraConfig["ApiToken"],
}
return NewOktaManager(oktaClientConfig, appMetrics)
case "google":
googleClientConfig := GoogleWorkspaceClientConfig{
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
CustomerID: config.ExtraConfig["CustomerId"],
}
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
case "jumpcloud":
jumpcloudConfig := JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
case "pocketid":
pocketidConfig := PocketIdClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
}
return NewPocketIdManager(pocketidConfig, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
}
}

View File

@@ -1,257 +0,0 @@
package idp
import (
"context"
"fmt"
"net/http"
"strings"
"time"
v1 "github.com/TheJumpCloud/jcapi-go/v1"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
contentType = "application/json"
accept = "application/json"
)
// JumpCloudManager JumpCloud manager client instance.
type JumpCloudManager struct {
client *v1.APIClient
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// JumpCloudClientConfig JumpCloud manager client configurations.
type JumpCloudClientConfig struct {
APIToken string
}
// JumpCloudCredentials JumpCloud authentication information.
type JumpCloudCredentials struct {
clientConfig JumpCloudClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
// NewJumpCloudManager creates a new instance of the JumpCloudManager.
func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.APIToken == "" {
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
}
client := v1.NewAPIClient(v1.NewConfiguration())
credentials := &JumpCloudCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &JumpCloudManager{
client: client,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// Authenticate retrieves access token to use the JumpCloud user API.
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}
func (jm *JumpCloudManager) authenticationContext() context.Context {
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
Key: jm.apiToken,
})
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
userData := parseJumpCloudUser(user)
userData.AppMetadata = appMetadata
return userData, nil
}
// GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
userData := parseJumpCloudUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
userData := parseJumpCloudUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{
"filter": []string{email},
"fields": []string{"email"},
},
}
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
usersData := make([]*UserData, 0)
for _, user := range userList.Results {
usersData = append(usersData, parseJumpCloudUser(user))
}
return usersData, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
names := []string{user.Firstname, user.Middlename, user.Lastname}
return &UserData{
Email: user.Email,
Name: strings.Join(names, " "),
ID: user.Id,
}
}

View File

@@ -1,46 +0,0 @@
package idp
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewJumpCloudManager(t *testing.T) {
type test struct {
name string
inputConfig JumpCloudClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := JumpCloudClientConfig{
APIToken: "test123",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.APIToken = ""
testCase2 := test{
name: "Missing APIToken Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}

View File

@@ -1,439 +0,0 @@
package idp
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// KeycloakManager keycloak manager client instance.
type KeycloakManager struct {
adminEndpoint string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// KeycloakClientConfig keycloak manager client configurations.
type KeycloakClientConfig struct {
ClientID string
ClientSecret string
AdminEndpoint string
TokenEndpoint string
GrantType string
}
// KeycloakCredentials keycloak authentication information.
type KeycloakCredentials struct {
clientConfig KeycloakClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
jwtToken JWTToken
mux sync.Mutex
appMetrics telemetry.AppMetrics
}
// keycloakUserAttributes holds additional user data fields.
type keycloakUserAttributes map[string][]string
// keycloakProfile represents a keycloak user profile response.
type keycloakProfile struct {
ID string `json:"id"`
CreatedTimestamp int64 `json:"createdTimestamp"`
Username string `json:"username"`
Email string `json:"email"`
Attributes keycloakUserAttributes `json:"attributes"`
}
// NewKeycloakManager creates a new instance of the KeycloakManager.
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
}
if config.ClientSecret == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
}
if config.TokenEndpoint == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, TokenEndpoint is missing")
}
if config.AdminEndpoint == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, GrantType is missing")
}
credentials := &KeycloakCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &KeycloakManager{
adminEndpoint: config.AdminEndpoint,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
func (kc *KeycloakCredentials) jwtStillValid() bool {
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
}
// requestJWTToken performs request to get jwt token.
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", kc.clientConfig.ClientID)
data.Set("client_secret", kc.clientConfig.ClientSecret)
data.Set("grant_type", kc.clientConfig.GrantType)
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
resp, err := kc.httpClient.Do(req)
if err != nil {
if kc.appMetrics != nil {
kc.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
}
return resp, nil
}
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{}
body, err := io.ReadAll(rawBody)
if err != nil {
return jwtToken, err
}
err = kc.helper.Unmarshal(body, &jwtToken)
if err != nil {
return jwtToken, err
}
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}
// Exp maps into exp from jwt token
var IssuedAt struct{ Exp int64 }
err = kc.helper.Unmarshal(data, &IssuedAt)
if err != nil {
return jwtToken, err
}
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
return jwtToken, nil
}
// Authenticate retrieves access token to use the keycloak Management API.
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
kc.mux.Lock()
defer kc.mux.Unlock()
if kc.appMetrics != nil {
kc.appMetrics.IDPMetrics().CountAuthenticate()
}
// reuse the token without requesting a new one if it is not expired,
// and if expiry time is sufficient time available to make a request.
if kc.jwtStillValid() {
return kc.jwtToken, nil
}
resp, err := kc.requestJWTToken(ctx)
if err != nil {
return kc.jwtToken, err
}
defer resp.Body.Close()
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
if err != nil {
return kc.jwtToken, err
}
kc.jwtToken = jwtToken
return kc.jwtToken, nil
}
// CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{}
q.Add("email", email)
q.Add("exact", "true")
body, err := km.get(ctx, "users", q)
if err != nil {
return nil, err
}
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountGetUserByEmail()
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles {
users = append(users, profile.userData())
}
return users, nil
}
// GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get(ctx, "users/"+userID, nil)
if err != nil {
return nil, err
}
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var profile keycloakProfile
err = km.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
return profile.userData(), nil
}
// GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
profiles, err := km.fetchAllUserProfiles(ctx)
if err != nil {
return nil, err
}
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, profile := range profiles {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
profiles, err := km.fetchAllUserProfiles(ctx)
if err != nil {
return nil, err
}
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := km.credentials.Authenticate(ctx)
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountDeleteUser()
}
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close() // nolint
// In the docs, they specified 200, but in the endpoints, they return 204
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount(ctx)
if err != nil {
return nil, err
}
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get(ctx, "users", q)
if err != nil {
return nil, err
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
return profiles, nil
}
// get perform Get requests.
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := km.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// totalUsersCount returns the total count of all user created.
// Used when fetching all registered accounts with pagination.
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
body, err := km.get(ctx, "users/count", nil)
if err != nil {
return nil, err
}
count, err := strconv.Atoi(string(body))
if err != nil {
return nil, err
}
return &count, nil
}
// userData construct user data from keycloak profile.
func (kp keycloakProfile) userData() *UserData {
return &UserData{
Email: kp.Email,
Name: kp.Username,
ID: kp.ID,
}
}

View File

@@ -1,310 +0,0 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewKeycloakManager(t *testing.T) {
type test struct {
name string
inputConfig KeycloakClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := KeycloakClientConfig{
ClientID: "client_id",
ClientSecret: "client_secret",
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
GrantType: "client_credentials",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.ClientID = ""
testCase2 := test{
name: "Missing ClientID Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase3Config := defaultTestConfig
testCase3Config.ClientSecret = ""
testCase3 := test{
name: "Missing ClientSecret Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase4Config := defaultTestConfig
testCase4Config.TokenEndpoint = ""
testCase4 := test{
name: "Missing TokenEndpoint Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
testCase5Config := defaultTestConfig
testCase5Config.GrantType = ""
testCase5 := test{
name: "Missing GrantType Configuration",
inputConfig: testCase3Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}
func TestKeycloakRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
name string
inputCode int
inputRespBody string
helper ManagerHelper
expectedFuncExitErrDiff error
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
requestJWTTokenTesttCase1 := requestJWTTokenTest{
name: "Good JWT Response",
inputCode: 200,
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
}
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
inputRespBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
}
config := KeycloakClientConfig{}
creds := KeycloakCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
resp, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
} else {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{}
err = testCase.helper.Unmarshal(body, &jwtToken)
assert.NoError(t, err, "unable to parse the json input")
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
}
})
}
}
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
type parseRequestJWTResponseTest struct {
name string
inputRespBody string
helper ManagerHelper
expectedToken string
expectedExpiresIn int
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
exp := 100
token := newTestJWT(t, exp)
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
name: "Parse Good JWT Body",
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedToken: token,
expectedExpiresIn: exp,
assertErrFunc: assert.NoError,
assertErrFuncMessage: "no error was expected",
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
inputRespBody: "",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
assertErrFunc: assert.Error,
assertErrFuncMessage: "json error was expected",
}
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
config := KeycloakClientConfig{}
creds := KeycloakCredentials{
clientConfig: config,
helper: testCase.helper,
}
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
})
}
}
func TestKeycloakJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
inputTime time.Time
expectedResult bool
message string
}
jwtStillValidTestCase1 := jwtStillValidTest{
name: "JWT still valid",
inputTime: time.Now().Add(10 * time.Second),
expectedResult: true,
message: "should be true",
}
jwtStillValidTestCase2 := jwtStillValidTest{
name: "JWT is invalid",
inputTime: time.Now(),
expectedResult: false,
message: "should be false",
}
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
config := KeycloakClientConfig{}
creds := KeycloakCredentials{
clientConfig: config,
}
creds.jwtToken.expiresInTime = testCase.inputTime
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
})
}
}
func TestKeycloakAuthenticate(t *testing.T) {
type authenticateTest struct {
name string
inputCode int
inputResBody string
inputExpireToken time.Time
helper ManagerHelper
expectedFuncExitErrDiff error
expectedCode int
expectedToken string
}
exp := 5
token := newTestJWT(t, exp)
authenticateTestCase1 := authenticateTest{
name: "Get Cached token",
inputExpireToken: time.Now().Add(30 * time.Second),
helper: JsonParser{},
expectedFuncExitErrDiff: nil,
expectedCode: 200,
expectedToken: "",
}
authenticateTestCase2 := authenticateTest{
name: "Get Good JWT Response",
inputCode: 200,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
helper: JsonParser{},
expectedCode: 200,
expectedToken: token,
}
authenticateTestCase3 := authenticateTest{
name: "Get Bad Status Code",
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
expectedCode: 200,
expectedToken: "",
}
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
jwtReqClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputCode,
}
config := KeycloakClientConfig{}
creds := KeycloakCredentials{
clientConfig: config,
httpClient: &jwtReqClient,
helper: testCase.helper,
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
} else {
t.Fatal(err)
}
}
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
})
}
}

View File

@@ -4,52 +4,26 @@ import "context"
// MockIDP is a mock implementation of the IDP interface
type MockIDP struct {
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
InviteUserByIDFunc func(ctx context.Context, userID string) error
DeleteUserFunc func(ctx context.Context, userID string) error
}
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
if m.UpdateUserAppMetadataFunc != nil {
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
}
return nil
}
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
if m.GetUserDataByIDFunc != nil {
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
}
return nil, nil
}
// GetAccount is a mock implementation of the IDP interface GetAccount method
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
if m.GetAccountFunc != nil {
return m.GetAccountFunc(ctx, accountId)
}
return nil, nil
}
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
if m.GetAllAccountsFunc != nil {
return m.GetAllAccountsFunc(ctx)
}
return nil, nil
CreateUserFunc func(ctx context.Context, email, name string) (*UserData, error)
GetUserDataByIDFunc func(ctx context.Context, userId string) (*UserData, error)
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
GetAllUsersFunc func(ctx context.Context) ([]*UserData, error)
InviteUserByIDFunc func(ctx context.Context, userID string) error
DeleteUserFunc func(ctx context.Context, userID string) error
}
// CreateUser is a mock implementation of the IDP interface CreateUser method
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
func (m *MockIDP) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
if m.CreateUserFunc != nil {
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
return m.CreateUserFunc(ctx, email, name)
}
return nil, nil
}
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string) (*UserData, error) {
if m.GetUserDataByIDFunc != nil {
return m.GetUserDataByIDFunc(ctx, userId)
}
return nil, nil
}
@@ -62,6 +36,14 @@ func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData
return nil, nil
}
// GetAllUsers is a mock implementation of the IDP interface GetAllUsers method
func (m *MockIDP) GetAllUsers(ctx context.Context) ([]*UserData, error) {
if m.GetAllUsersFunc != nil {
return m.GetAllUsersFunc(ctx)
}
return nil, nil
}
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
if m.InviteUserByIDFunc != nil {

View File

@@ -1,306 +0,0 @@
package idp
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// OktaManager okta manager client instance.
type OktaManager struct {
client *okta.Client
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// OktaClientConfig okta manager client configurations.
type OktaClientConfig struct {
APIToken string
Issuer string
TokenEndpoint string
GrantType string
}
// OktaCredentials okta authentication information.
type OktaCredentials struct {
clientConfig OktaClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
// NewOktaManager creates a new instance of the OktaManager.
func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*OktaManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
config.Issuer = baseURL(config.Issuer)
if config.APIToken == "" {
return nil, fmt.Errorf("okta IdP configuration is incomplete, APIToken is missing")
}
if config.Issuer == "" {
return nil, fmt.Errorf("okta IdP configuration is incomplete, Issuer is missing")
}
if config.TokenEndpoint == "" {
return nil, fmt.Errorf("okta IdP configuration is incomplete, TokenEndpoint is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("okta IdP configuration is incomplete, GrantType is missing")
}
_, client, err := okta.NewClient(context.Background(),
okta.WithOrgUrl(config.Issuer),
okta.WithToken(config.APIToken),
okta.WithHttpClientPtr(httpClient),
)
if err != nil {
return nil, err
}
credentials := &OktaCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &OktaManager{
client: client,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// Authenticate retrieves access token to use the okta user API.
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}
// CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil {
return nil, err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountGetUserDataByID()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
userData.AppMetadata = appMetadata
return userData, nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
if err != nil {
return nil, err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountGetUserByEmail()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
}
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, userData)
return users, nil
}
// GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
users, err := om.getAllUsers()
if err != nil {
return nil, err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountGetAccount()
}
for index, user := range users {
user.AppMetadata.WTAccountID = accountID
users[index] = user
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
users, err := om.getAllUsers()
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountGetAllAccounts()
}
return indexedUsers, nil
}
// getAllUsers returns all users in an Okta account.
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
qp := query.NewQueryParams(query.WithLimit(200))
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
for resp.HasNextPage() {
paginatedUsers := make([]*okta.User, 0)
resp, err = resp.Next(context.Background(), &paginatedUsers)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
userList = append(userList, paginatedUsers...)
}
users := make([]*UserData, 0, len(userList))
for _, user := range userList {
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
users = append(users, userData)
}
return users, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Okta
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
if err != nil {
return err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountDeleteUser()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
return nil
}
// parseOktaUser parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct {
Email string `json:"email"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
}
if user == nil {
return nil, fmt.Errorf("invalid okta user")
}
if user.Profile != nil {
helper := JsonParser{}
buf, err := helper.Marshal(*user.Profile)
if err != nil {
return nil, err
}
err = helper.Unmarshal(buf, &oktaUser)
if err != nil {
return nil, err
}
}
return &UserData{
Email: oktaUser.Email,
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
ID: user.Id,
}, nil
}

View File

@@ -1,65 +0,0 @@
package idp
import (
"testing"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/stretchr/testify/assert"
)
func TestParseOktaUser(t *testing.T) {
type parseOktaUserTest struct {
name string
inputProfile *okta.User
expectedUserData *UserData
assertErrFunc assert.ErrorAssertionFunc
}
parseOktaTestCase1 := parseOktaUserTest{
name: "Good Request",
inputProfile: &okta.User{
Id: "123",
Profile: &okta.UserProfile{
"email": "test@example.com",
"firstName": "John",
"lastName": "Doe",
},
},
expectedUserData: &UserData{
Email: "test@example.com",
Name: "John Doe",
ID: "123",
AppMetadata: AppMetadata{
WTAccountID: "456",
},
},
assertErrFunc: assert.NoError,
}
parseOktaTestCase2 := parseOktaUserTest{
name: "Invalid okta user",
inputProfile: nil,
expectedUserData: nil,
assertErrFunc: assert.Error,
}
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
userData, err := parseOktaUser(testCase.inputProfile)
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
if err == nil {
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
}
})
}
}
// userDataEqual helper function to compare UserData structs for equality.
func userDataEqual(a, b *UserData) bool {
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
return false
}
return true
}

View File

@@ -1,384 +0,0 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
)
type PocketIdManager struct {
managementEndpoint string
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
type pocketIdCustomClaimDto struct {
Key string `json:"key"`
Value string `json:"value"`
}
type pocketIdUserDto struct {
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
Disabled bool `json:"disabled"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
FirstName string `json:"firstName"`
ID string `json:"id"`
IsAdmin bool `json:"isAdmin"`
LastName string `json:"lastName"`
LdapID string `json:"ldapId"`
Locale string `json:"locale"`
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
Username string `json:"username"`
}
type pocketIdUserCreateDto struct {
Disabled bool `json:"disabled,omitempty"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
FirstName string `json:"firstName"`
IsAdmin bool `json:"isAdmin,omitempty"`
LastName string `json:"lastName,omitempty"`
Locale string `json:"locale,omitempty"`
Username string `json:"username"`
}
type pocketIdPaginatedUserDto struct {
Data []pocketIdUserDto `json:"data"`
Pagination pocketIdPaginationDto `json:"pagination"`
}
type pocketIdPaginationDto struct {
CurrentPage int `json:"currentPage"`
ItemsPerPage int `json:"itemsPerPage"`
TotalItems int `json:"totalItems"`
TotalPages int `json:"totalPages"`
}
func (p *pocketIdUserDto) userData() *UserData {
return &UserData{
Email: p.Email,
Name: p.DisplayName,
ID: p.ID,
AppMetadata: AppMetadata{},
}
}
type pocketIdUserGroupDto struct {
CreatedAt string `json:"createdAt"`
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
FriendlyName string `json:"friendlyName"`
ID string `json:"id"`
LdapID string `json:"ldapId"`
Name string `json:"name"`
}
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ManagementEndpoint == "" {
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
}
if config.APIToken == "" {
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
}
credentials := &PocketIdCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &PocketIdManager{
managementEndpoint: config.ManagementEndpoint,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
if !slices.Contains(MethodsWithBody, method) && body != "" {
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
}
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
if query != nil {
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
}
var req *http.Request
var err error
if body != "" {
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
} else {
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
}
if err != nil {
return nil, err
}
req.Header.Add("X-API-KEY", p.apiToken)
if body != "" {
req.Header.Add("content-type", "application/json")
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
}
resp, err := p.httpClient.Do(req)
if err != nil {
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// getAllUsersPaginated fetches all users from PocketID API using pagination
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
var allUsers []pocketIdUserDto
currentPage := 1
for {
params := url.Values{}
// Copy existing search parameters
for key, values := range searchParams {
params[key] = values
}
params.Set("pagination[limit]", "100")
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
body, err := p.request(ctx, http.MethodGet, "users", &params, "")
if err != nil {
return nil, err
}
var profiles pocketIdPaginatedUserDto
err = p.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
allUsers = append(allUsers, profiles.Data...)
// Check if we've reached the last page
if currentPage >= profiles.Pagination.TotalPages {
break
}
currentPage++
}
return allUsers, nil
}
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var user pocketIdUserDto
err = p.helper.Unmarshal(body, &user)
if err != nil {
return nil, err
}
userData := user.userData()
userData.AppMetadata = appMetadata
return userData, nil
}
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
// Get all users using pagination
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, profile := range allUsers {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountId
users = append(users, userData)
}
return users, nil
}
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
// Get all users using pagination
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range allUsers {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.Split(name, " ")
createUser := pocketIdUserCreateDto{
Disabled: false,
DisplayName: name,
Email: email,
FirstName: firstLast[0],
LastName: firstLast[1],
Username: firstLast[0] + "." + firstLast[1],
}
payload, err := p.helper.Marshal(createUser)
if err != nil {
return nil, err
}
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
if err != nil {
return nil, err
}
var newUser pocketIdUserDto
err = p.helper.Unmarshal(body, &newUser)
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountCreateUser()
}
var pending bool = true
ret := &UserData{
Email: email,
Name: name,
ID: newUser.ID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pending,
WTInvitedBy: invitedByEmail,
},
}
return ret, nil
}
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
params := url.Values{
// This value a
"search": []string{email},
}
body, err := p.request(ctx, http.MethodGet, "users", &params, "")
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profiles struct{ data []pocketIdUserDto }
err = p.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.data {
users = append(users, profile.userData())
}
return users, nil
}
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
if err != nil {
return err
}
return nil
}
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
if err != nil {
return err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
var _ Manager = (*PocketIdManager)(nil)
type PocketIdClientConfig struct {
APIToken string
ManagementEndpoint string
}
type PocketIdCredentials struct {
clientConfig PocketIdClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}

View File

@@ -1,137 +0,0 @@
package idp
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewPocketIdManager(t *testing.T) {
type test struct {
name string
inputConfig PocketIdClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := PocketIdClientConfig{
APIToken: "api_token",
ManagementEndpoint: "http://localhost",
}
tests := []test{
{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
},
{
name: "Missing ManagementEndpoint",
inputConfig: PocketIdClientConfig{
APIToken: defaultTestConfig.APIToken,
ManagementEndpoint: "",
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
{
name: "Missing APIToken",
inputConfig: PocketIdClientConfig{
APIToken: "",
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
})
}
}
func TestPocketID_GetUserDataByID(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
md := AppMetadata{WTAccountID: "acc1"}
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
require.NoError(t, err)
assert.Equal(t, "u1", got.ID)
assert.Equal(t, "user1@example.com", got.Email)
assert.Equal(t, "User One", got.Name)
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
}
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
// Single page response with two users
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
users, err := mgr.GetAccount(context.Background(), "accX")
require.NoError(t, err)
require.Len(t, users, 2)
assert.Equal(t, "u1", users[0].ID)
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "u2", users[1].ID)
}
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
accounts, err := mgr.GetAllAccounts(context.Background())
require.NoError(t, err)
require.Len(t, accounts[UnsetAccountID], 2)
}
func TestPocketID_CreateUser(t *testing.T) {
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
require.NoError(t, err)
assert.Equal(t, "newid", ud.ID)
assert.Equal(t, "new@example.com", ud.Email)
assert.Equal(t, "New User", ud.Name)
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
assert.True(t, *ud.AppMetadata.WTPendingInvite)
}
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
}
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
// Same mock for both calls; returns OK with empty JSON
client := &mockHTTPClient{code: 200, resBody: `{}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
err = mgr.InviteUserByID(context.Background(), "u1")
require.NoError(t, err)
err = mgr.DeleteUser(context.Background(), "u1")
require.NoError(t, err)
}

View File

@@ -0,0 +1,47 @@
package idp
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"testing"
"time"
)
// mockHTTPClient is a mock implementation of ManagerHTTPClient for testing
type mockHTTPClient struct {
code int
resBody string
reqBody string
err error
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
if c.err != nil {
return nil, c.err
}
if req.Body != nil {
body, _ := io.ReadAll(req.Body)
c.reqBody = string(body)
}
return &http.Response{
StatusCode: c.code,
Body: io.NopCloser(bytes.NewReader([]byte(c.resBody))),
}, nil
}
// newTestJWT creates a test JWT token with the given expiration time in seconds
func newTestJWT(t *testing.T, expiresIn int) string {
t.Helper()
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
exp := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
payload := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))
signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature"))
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
}

View File

@@ -319,14 +319,19 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
}
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
// Note: Authorization data (account membership, invite status) is stored in NetBird's DB, not in the IdP.
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
firstLast := strings.SplitN(name, " ", 2)
lastName := firstLast[0]
if len(firstLast) > 1 {
lastName = firstLast[1]
}
var addUser = map[string]any{
"userName": email,
"profile": map[string]string{
"firstName": firstLast[0],
"lastName": firstLast[0],
"lastName": lastName,
"displayName": name,
},
"email": map[string]any{
@@ -357,18 +362,11 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
return nil, err
}
var pending bool = true
ret := &UserData{
return &UserData{
Email: email,
Name: name,
ID: newUser.UserId,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pending,
WTInvitedBy: invitedByEmail,
},
}
return ret, nil
}, nil
}
// GetUserByEmail searches users with a given email.
@@ -413,7 +411,7 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
}
// GetUserDataByID requests user data from zitadel via ID.
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string) (*UserData, error) {
body, err := zm.get(ctx, "users/"+userID, nil)
if err != nil {
return nil, err
@@ -429,43 +427,12 @@ func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, ap
return nil, err
}
userData := profile.User.userData()
userData.AppMetadata = appMetadata
return userData, nil
return profile.User.userData(), nil
}
// GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
body, err := zm.post(ctx, "users/_search", "")
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountGetAccount()
}
var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
// GetAllUsers returns all users from the IdP.
// Used for cache warming - NetBird matches these against its own user database.
func (zm *ZitadelManager) GetAllUsers(ctx context.Context) ([]*UserData, error) {
body, err := zm.post(ctx, "users/_search", "")
if err != nil {
return nil, err
@@ -481,19 +448,12 @@ func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*Use
return nil, err
}
indexedUsers := make(map[string][]*UserData)
users := make([]*UserData, 0, len(profiles.Result))
for _, profile := range profiles.Result {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
users = append(users, profile.userData())
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
return users, nil
}
type inviteUserRequest struct {

View File

@@ -288,16 +288,14 @@ func TestZitadelAuthenticate(t *testing.T) {
}
func TestZitadelProfile(t *testing.T) {
type azureProfileTest struct {
type zitadelProfileTest struct {
name string
invite bool
inputProfile zitadelProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "User Request",
invite: false,
zitadelProfileTestCase1 := zitadelProfileTest{
name: "User Request",
inputProfile: zitadelProfile{
ID: "test1",
State: "USER_STATE_ACTIVE",
@@ -322,15 +320,11 @@ func TestZitadelProfile(t *testing.T) {
ID: "test1",
Name: "ZITADEL Admin",
Email: "test1@mail.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Service User Request",
invite: true,
zitadelProfileTestCase2 := zitadelProfileTest{
name: "Service User Request",
inputProfile: zitadelProfile{
ID: "test2",
State: "USER_STATE_ACTIVE",
@@ -345,15 +339,11 @@ func TestZitadelProfile(t *testing.T) {
ID: "test2",
Name: "machine",
Email: "machine",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
for _, testCase := range []zitadelProfileTest{zitadelProfileTestCase1, zitadelProfileTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")

View File

@@ -16,7 +16,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -487,7 +486,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
userdata, err := am.idpManager.GetUserDataByID(ctx, userID)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -85,6 +85,8 @@ type User struct {
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// PendingInvite indicates whether the user has accepted their invite and logged in
PendingInvite bool
// PendingApproval indicates whether the user requires approval before being activated
PendingApproval bool
// LastLogin is the last time the user logged in to IdP
@@ -162,7 +164,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
}
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
if u.PendingInvite {
userStatus = UserStatusInvited
}
@@ -199,6 +201,7 @@ func (u *User) Copy() *User {
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
PendingInvite: u.PendingInvite,
PendingApproval: u.PendingApproval,
LastLogin: u.LastLogin,
CreatedAt: u.CreatedAt,

View File

@@ -117,6 +117,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference,
CreatedAt: time.Now().UTC(),
PendingInvite: true, // User hasn't accepted invite yet
}
if err = am.Store.SaveUser(ctx, newUser); err != nil {
@@ -169,7 +170,7 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
}
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email)
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name)
}
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
@@ -285,8 +286,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
return status.NewPermissionDeniedError()
}
// check if the user is already registered with this ID
user, err := am.lookupUserInCache(ctx, targetUserID, accountID)
// Get user from NetBird's database (source of truth for authorization data)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return err
}
@@ -295,18 +296,17 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
return status.Errorf(status.NotFound, "user account %s doesn't exist", targetUserID)
}
// check if user account is already invited and account is not activated
pendingInvite := user.AppMetadata.WTPendingInvite
if pendingInvite == nil || !*pendingInvite {
// Check if user is still pending invite (hasn't activated their account)
if !user.PendingInvite {
return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account")
}
err = am.idpManager.InviteUserByID(ctx, user.ID)
err = am.idpManager.InviteUserByID(ctx, targetUserID)
if err != nil {
return err
}
am.StoreEvent(ctx, initiatorUserID, user.ID, accountID, activity.UserInvited, nil)
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserInvited, nil)
return nil
}
@@ -1011,17 +1011,15 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error {
if am.userDeleteFromIDPEnabled {
log.WithContext(ctx).Debugf("user %s deleted from IdP", targetUserID)
log.WithContext(ctx).Debugf("deleting user %s from IdP", targetUserID)
err := am.idpManager.DeleteUser(ctx, targetUserID)
if err != nil {
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
}
} else {
err := am.idpManager.UpdateUserAppMetadata(ctx, targetUserID, idp.AppMetadata{})
if err != nil {
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
}
}
// Note: If userDeleteFromIDPEnabled is false, the user remains in IdP but is removed from NetBird.
// This allows the user to re-authenticate if they're re-invited later.
err := am.removeUserFromCache(ctx, accountID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err)
@@ -1095,7 +1093,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
if !isNil(am.idpManager) {
// Delete if the user already exists in the IdP. Necessary in cases where a user account
// was created where a user account was provisioned but the user did not sign in
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID, idp.AppMetadata{WTAccountID: accountID})
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID)
if err == nil {
err = am.deleteUserFromIDP(ctx, targetUserInfo.ID, accountID)
if err != nil {

View File

@@ -563,7 +563,7 @@ func TestUser_InviteNewUser(t *testing.T) {
}
idpMock := idp.MockIDP{
CreateUserFunc: func(_ context.Context, email, name, accountID, invitedByEmail string) (*idp.UserData, error) {
CreateUserFunc: func(_ context.Context, email, name string) (*idp.UserData, error) {
newData := &idp.UserData{
Email: email,
Name: name,
@@ -574,7 +574,7 @@ func TestUser_InviteNewUser(t *testing.T) {
return newData, nil
},
GetAccountFunc: func(_ context.Context, accountId string) ([]*idp.UserData, error) {
GetAllUsersFunc: func(_ context.Context) ([]*idp.UserData, error) {
return mockData, nil
},
}
@@ -1068,7 +1068,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
idpManager: &idp.MockIDP{}, // empty manager
cacheLoading: map[string]chan struct{}{},
permissionsManager: permissionsManager,
}