mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
2 Commits
feature/us
...
feat/integ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9291e3134b | ||
|
|
eb578146e4 |
352
docs/plans/self-service-auth-plan.mdx
Normal file
352
docs/plans/self-service-auth-plan.mdx
Normal file
@@ -0,0 +1,352 @@
|
||||
# Zitadel Integration Plan
|
||||
|
||||
This plan is split into stages. Each stage is self-contained and can be implemented and tested independently.
|
||||
|
||||
---
|
||||
|
||||
## Stage 1: Infrastructure Setup [COMPLETED]
|
||||
|
||||
**Goal**: Add Zitadel to docker-compose with first-boot initialization (single container, SQLite).
|
||||
|
||||
### Design Principles
|
||||
- **KISS**: Single Zitadel container with SQLite (no separate PostgreSQL or init containers)
|
||||
- **DRY**: Leverage Zitadel's built-in first-instance configuration via environment variables
|
||||
- **Clean**: Credentials generated in configure.sh and printed directly
|
||||
|
||||
### Files Created
|
||||
|
||||
#### `infrastructure_files/Caddyfile.tmpl`
|
||||
Reverse proxy configuration routing:
|
||||
- `/api/*`, `/management.*` → Management server
|
||||
- `/relay*` → Relay
|
||||
- `/signalexchange.*`, `/signal*` → Signal
|
||||
- `/oauth/*`, `/oidc/*`, `/.well-known/*`, `/ui/*` → Zitadel
|
||||
- `/*` → Dashboard
|
||||
|
||||
### Files Modified
|
||||
|
||||
#### `infrastructure_files/docker-compose.yml.tmpl`
|
||||
Added single Zitadel service using SQLite:
|
||||
```yaml
|
||||
zitadel:
|
||||
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
environment:
|
||||
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
|
||||
# ... service account config via env vars
|
||||
```
|
||||
|
||||
#### `infrastructure_files/management.json.tmpl`
|
||||
Updated IdpManagerConfig to default to Zitadel.
|
||||
|
||||
#### `infrastructure_files/base.setup.env`
|
||||
Added Zitadel-specific environment variables:
|
||||
- `ZITADEL_TAG` (default: v4.7.6)
|
||||
- `ZITADEL_MASTERKEY` (auto-generated)
|
||||
- `ZITADEL_ADMIN_USERNAME`, `ZITADEL_ADMIN_PASSWORD` (auto-generated)
|
||||
- `ZITADEL_EXTERNALSECURE`, `ZITADEL_EXTERNALPORT`, `ZITADEL_TLS_MODE`
|
||||
|
||||
#### `infrastructure_files/configure.sh`
|
||||
- Generates Zitadel masterkey if not set
|
||||
- Generates admin credentials if not set
|
||||
- Prints credentials to stdout during configuration
|
||||
|
||||
### Deliverable
|
||||
Running `./configure.sh && cd artifacts && docker compose up -d` starts full stack with Zitadel. Admin credentials printed during configuration.
|
||||
|
||||
---
|
||||
|
||||
## Stage 2: Simplify IdP Integration
|
||||
|
||||
**Goal**: Make NetBird the single source of truth for user authorization data. Zitadel handles authentication only.
|
||||
|
||||
### Design Principles
|
||||
- **Single source of truth**: NetBird DB stores all authorization data (roles, account membership, invite status)
|
||||
- **Clean separation**: Zitadel stores identity only (email, name, password)
|
||||
- **No duplicate data**: Remove `wt_account_id` and `wt_pending_invite` from IdP metadata
|
||||
|
||||
### Files to Modify
|
||||
|
||||
#### `management/server/types/user.go`
|
||||
Add `PendingInvite` field to User struct:
|
||||
```go
|
||||
type User struct {
|
||||
// ... existing fields ...
|
||||
PendingInvite bool // NEW: tracks if user has accepted invite
|
||||
}
|
||||
```
|
||||
|
||||
Update `ToUserInfo()` to use local field instead of IdP metadata:
|
||||
```go
|
||||
// Before
|
||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
// After
|
||||
if u.PendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
```
|
||||
|
||||
#### `management/server/idp/idp.go`
|
||||
Simplify `Manager` interface:
|
||||
```go
|
||||
type Manager interface {
|
||||
// Simplified - no more accountID or invitedByEmail params
|
||||
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
```
|
||||
|
||||
Remove:
|
||||
- `UpdateUserAppMetadata()` - no longer needed
|
||||
- `GetAccount()` - account membership now in NetBird DB
|
||||
- `GetAllAccounts()` - account membership now in NetBird DB
|
||||
- `AppMetadata` struct fields (`WTAccountID`, `WTPendingInvite`, `WTInvitedBy`)
|
||||
|
||||
#### `management/server/idp/zitadel.go`
|
||||
- Update `CreateUser()` to not set metadata
|
||||
- Remove `UpdateUserAppMetadata()` implementation
|
||||
- Simplify `GetUserDataByID()` to not expect metadata
|
||||
|
||||
#### `management/server/user.go`
|
||||
Update user creation flow:
|
||||
1. `inviteNewUser()`: Set `PendingInvite = true` when creating user
|
||||
2. On first login: Set `PendingInvite = false`
|
||||
|
||||
### Data Flow (After)
|
||||
|
||||
```
|
||||
Invite User:
|
||||
1. Admin invites user@example.com via NetBird UI
|
||||
2. NetBird calls Zitadel: CreateUser(email, name) // No metadata
|
||||
3. Zitadel creates user, sends invite email
|
||||
4. NetBird creates User in its DB:
|
||||
- Id = Zitadel user ID
|
||||
- AccountID = current account
|
||||
- Role = invited role
|
||||
- PendingInvite = true
|
||||
|
||||
First Login:
|
||||
1. User clicks invite, sets password in Zitadel
|
||||
2. User logs into NetBird Dashboard
|
||||
3. NetBird updates User in its DB:
|
||||
- PendingInvite = false
|
||||
- LastLogin = now
|
||||
```
|
||||
|
||||
### Deliverable
|
||||
- NetBird is single source of truth for authorization
|
||||
- Zitadel stores identity only (email, name, credentials)
|
||||
- No metadata sync issues between systems
|
||||
|
||||
---
|
||||
|
||||
## Stage 3: Remove Legacy IdP Managers
|
||||
|
||||
**Goal**: Remove all non-Zitadel IdP implementations.
|
||||
|
||||
### Files to Delete
|
||||
- `management/server/idp/auth0.go`
|
||||
- `management/server/idp/auth0_test.go`
|
||||
- `management/server/idp/azure.go`
|
||||
- `management/server/idp/azure_test.go`
|
||||
- `management/server/idp/keycloak.go`
|
||||
- `management/server/idp/keycloak_test.go`
|
||||
- `management/server/idp/okta.go`
|
||||
- `management/server/idp/okta_test.go`
|
||||
- `management/server/idp/google_workspace.go`
|
||||
- `management/server/idp/google_workspace_test.go`
|
||||
- `management/server/idp/jumpcloud.go`
|
||||
- `management/server/idp/jumpcloud_test.go`
|
||||
- `management/server/idp/authentik.go`
|
||||
- `management/server/idp/authentik_test.go`
|
||||
- `management/server/idp/pocketid.go`
|
||||
- `management/server/idp/pocketid_test.go`
|
||||
|
||||
### Files to Modify
|
||||
|
||||
#### `management/server/idp/idp.go`
|
||||
Simplify `NewManager()` to only support Zitadel:
|
||||
```go
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
switch strings.ToLower(config.ManagerType) {
|
||||
case "none", "":
|
||||
return nil, nil
|
||||
case "zitadel":
|
||||
return NewZitadelManager(...)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### `management/server/idp/idp.go`
|
||||
Remove unused config structs:
|
||||
- `Auth0ClientConfig`
|
||||
- `AzureClientConfig`
|
||||
- `KeycloakClientConfig`
|
||||
- etc.
|
||||
|
||||
### Deliverable
|
||||
Only Zitadel IdP manager remains. Build passes. Tests pass.
|
||||
|
||||
---
|
||||
|
||||
## Stage 4: External IdP Connector API
|
||||
|
||||
**Goal**: API for users to add external IdPs (Okta, Google, LDAP) as Zitadel connectors.
|
||||
|
||||
### New Files
|
||||
|
||||
#### `management/server/idp/connectors.go`
|
||||
Wrapper for Zitadel IdP connector API:
|
||||
```go
|
||||
type Connector struct {
|
||||
ID string
|
||||
Name string
|
||||
Type string // "oidc", "ldap", "saml"
|
||||
Issuer string // for OIDC
|
||||
ClientID string // for OIDC
|
||||
}
|
||||
|
||||
type ConnectorManager interface {
|
||||
AddOIDCConnector(ctx, name, issuer, clientID, clientSecret string, scopes []string) (*Connector, error)
|
||||
AddLDAPConnector(ctx, name, host string, port int, baseDN, bindDN, bindPassword string) (*Connector, error)
|
||||
ListConnectors(ctx) ([]Connector, error)
|
||||
DeleteConnector(ctx, connectorID string) error
|
||||
}
|
||||
```
|
||||
|
||||
#### `management/server/http/handlers/idp_connectors_handler.go`
|
||||
REST API handlers:
|
||||
```
|
||||
POST /api/idp/connectors - Add connector
|
||||
GET /api/idp/connectors - List connectors
|
||||
DELETE /api/idp/connectors/{id} - Delete connector
|
||||
```
|
||||
|
||||
### Zitadel API Endpoints Used
|
||||
- `POST /management/v1/idps/generic_oidc` - Add OIDC connector
|
||||
- `POST /management/v1/idps/ldap` - Add LDAP connector
|
||||
- `GET /management/v1/idps` - List all IdPs
|
||||
- `DELETE /management/v1/idps/{id}` - Remove IdP
|
||||
|
||||
### Deliverable
|
||||
Admin users can add/remove external IdP connectors via NetBird API.
|
||||
|
||||
---
|
||||
|
||||
## Stage 5: User Role Permissions
|
||||
|
||||
**Goal**: Enforce admin/user permissions throughout the application.
|
||||
|
||||
### Files to Modify
|
||||
|
||||
#### `management/server/types/user.go`
|
||||
Simplify roles to:
|
||||
```go
|
||||
const (
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
)
|
||||
```
|
||||
|
||||
Keep `owner` as alias for `admin` for backwards compatibility.
|
||||
|
||||
#### Permission checks (various files)
|
||||
Update permission checks to use simplified role model:
|
||||
- Admin: Full access to all resources
|
||||
- User: View/manage own peers only, no user/policy management
|
||||
|
||||
### Permission Matrix
|
||||
| Resource | Admin | User |
|
||||
|----------|-------|------|
|
||||
| All peers | read/write | - |
|
||||
| Own peers | read/write | read/write |
|
||||
| Users | read/write | - |
|
||||
| Groups | read/write | read (own) |
|
||||
| Policies | read/write | - |
|
||||
| IdP Connectors | read/write | - |
|
||||
| Account settings | read/write | - |
|
||||
|
||||
### Deliverable
|
||||
Role-based access control enforced. User role has limited dashboard access.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ NETBIRD DEPLOYMENT │
|
||||
│ │
|
||||
│ ┌───────────────────────────────────────────────────────────┐ │
|
||||
│ │ CADDY (Reverse Proxy) │ │
|
||||
│ │ :80/:443 → routes to appropriate service │ │
|
||||
│ └───────────────────────────────────────────────────────────┘ │
|
||||
│ │ │ │ │ │
|
||||
│ ▼ ▼ ▼ ▼ │
|
||||
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
|
||||
│ │ Dashboard│ │Management│ │ Signal │ │ Zitadel │ │
|
||||
│ │ :80 │ │ :80 │ │ :80 │ │ :8080 (SQLite) │ │
|
||||
│ └──────────┘ └────┬─────┘ └──────────┘ └──────────────────┘ │
|
||||
│ │ │ │
|
||||
│ │ OIDC + Mgmt API │ │
|
||||
│ └──────────────────────────┘ │
|
||||
│ │
|
||||
│ ┌───────────────────────────────────────────────────────────┐ │
|
||||
│ │ Relay │ Coturn (TURN) │ │
|
||||
│ └───────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. **Stage 1** - Infrastructure (can be tested standalone)
|
||||
2. **Stage 2** - Role management (requires Stage 1)
|
||||
3. **Stage 3** - Remove legacy IdPs (can be done in parallel with Stage 2)
|
||||
4. **Stage 4** - Connector API (requires Stage 1)
|
||||
5. **Stage 5** - Permissions (requires Stage 2)
|
||||
|
||||
Stages 3 and 4 can be worked on in parallel after Stage 1 is complete.
|
||||
|
||||
---
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### First-Boot Credentials
|
||||
- Generated with `openssl rand -base64 32` in configure.sh
|
||||
- Printed to stdout during configuration (save them!)
|
||||
- Marked as "must change on first login" in Zitadel via `PASSWORDCHANGEREQUIRED=true`
|
||||
|
||||
### Zitadel Masterkey
|
||||
- Auto-generated if not provided (32 bytes)
|
||||
- Stored in environment variable (docker secret in production)
|
||||
- Used for Zitadel's internal encryption
|
||||
|
||||
### Service Account
|
||||
- Machine user created via `ZITADEL_FIRSTINSTANCE_ORG_MACHINE_*` env vars
|
||||
- PAT (Personal Access Token) generated with 1-year expiration
|
||||
- Used by NetBird management to call Zitadel API
|
||||
|
||||
---
|
||||
|
||||
## Migration Notes
|
||||
|
||||
Existing deployments using Auth0/Azure/Keycloak will need to:
|
||||
1. Deploy Zitadel alongside existing IdP
|
||||
2. Add existing IdP as Zitadel connector
|
||||
3. Migrate users (or let them re-authenticate via connector)
|
||||
4. Switch NetBird config to use Zitadel
|
||||
5. Remove old IdP
|
||||
|
||||
A migration guide should be provided as documentation.
|
||||
83
infrastructure_files/Caddyfile.tmpl
Normal file
83
infrastructure_files/Caddyfile.tmpl
Normal file
@@ -0,0 +1,83 @@
|
||||
{
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
}
|
||||
|
||||
(security_headers) {
|
||||
header * {
|
||||
# HSTS - use 1 hour for testing, increase to 63072000 (2 years) in production
|
||||
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||
# Prevent MIME type sniffing
|
||||
X-Content-Type-Options "nosniff"
|
||||
# Clickjacking protection
|
||||
X-Frame-Options "SAMEORIGIN"
|
||||
# XSS protection
|
||||
X-XSS-Protection "1; mode=block"
|
||||
# Remove server header
|
||||
-Server
|
||||
# Referrer policy
|
||||
Referrer-Policy strict-origin-when-cross-origin
|
||||
}
|
||||
}
|
||||
|
||||
:${NETBIRD_CADDY_PORT}${CADDY_SECURE_DOMAIN} {
|
||||
import security_headers
|
||||
|
||||
# Relay
|
||||
reverse_proxy /relay* relay:${NETBIRD_RELAY_INTERNAL_PORT}
|
||||
|
||||
# Signal - WebSocket proxy
|
||||
reverse_proxy /ws-proxy/signal* signal:80
|
||||
# Signal - gRPC
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||
|
||||
# Management - REST API
|
||||
reverse_proxy /api/* management:80
|
||||
# Management - WebSocket proxy
|
||||
reverse_proxy /ws-proxy/management* management:80
|
||||
# Management - gRPC
|
||||
reverse_proxy /management.ManagementService/* h2c://management:80
|
||||
|
||||
# Zitadel - Admin API
|
||||
reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080
|
||||
reverse_proxy /admin/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Auth API
|
||||
reverse_proxy /zitadel.auth.v1.AuthService/* h2c://zitadel:8080
|
||||
reverse_proxy /auth/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Management API
|
||||
reverse_proxy /zitadel.management.v1.ManagementService/* h2c://zitadel:8080
|
||||
reverse_proxy /management/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - System API
|
||||
reverse_proxy /zitadel.system.v1.SystemService/* h2c://zitadel:8080
|
||||
reverse_proxy /system/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - User API v2
|
||||
reverse_proxy /zitadel.user.v2.UserService/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Assets
|
||||
reverse_proxy /assets/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - UI (login, console, etc.)
|
||||
reverse_proxy /ui/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - OIDC endpoints
|
||||
reverse_proxy /oidc/v1/* h2c://zitadel:8080
|
||||
reverse_proxy /oauth/v2/* h2c://zitadel:8080
|
||||
reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080
|
||||
|
||||
# Zitadel - SAML
|
||||
reverse_proxy /saml/v2/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Other
|
||||
reverse_proxy /openapi/* h2c://zitadel:8080
|
||||
reverse_proxy /debug/* h2c://zitadel:8080
|
||||
reverse_proxy /device/* h2c://zitadel:8080
|
||||
reverse_proxy /device h2c://zitadel:8080
|
||||
|
||||
# Dashboard - catch-all for frontend
|
||||
reverse_proxy /* dashboard:80
|
||||
}
|
||||
@@ -2,29 +2,20 @@
|
||||
|
||||
# Management API
|
||||
|
||||
# Management API port
|
||||
NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073}
|
||||
# Management API endpoint address, used by the Dashboard
|
||||
NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
||||
# Management Certificate file path. These are generated by the Dashboard container
|
||||
NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem"
|
||||
# Management Certificate key file path.
|
||||
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem"
|
||||
# Management API endpoint address, used by the Dashboard (Caddy handles TLS)
|
||||
NETBIRD_MGMT_API_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
|
||||
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
|
||||
NETBIRD_SIGNAL_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
|
||||
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-443}
|
||||
|
||||
# Relay
|
||||
NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN}
|
||||
NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080}
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-rel://$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT}
|
||||
# Relay (internal port for Caddy reverse proxy)
|
||||
NETBIRD_RELAY_INTERNAL_PORT=${NETBIRD_RELAY_INTERNAL_PORT:-80}
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-${NETBIRD_RELAY_PROTO:-rels}://$NETBIRD_DOMAIN:${NETBIRD_RELAY_PORT:-443}}
|
||||
# Relay auth secret
|
||||
NETBIRD_RELAY_AUTH_SECRET=
|
||||
|
||||
@@ -141,3 +132,57 @@ export NETBIRD_RELAY_ENDPOINT
|
||||
export NETBIRD_RELAY_AUTH_SECRET
|
||||
export NETBIRD_RELAY_TAG
|
||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
# Zitadel IdP Configuration
|
||||
ZITADEL_TAG=${ZITADEL_TAG:-"v4.7.6"}
|
||||
# Zitadel masterkey (32 bytes, auto-generated if not set)
|
||||
ZITADEL_MASTERKEY=
|
||||
# Zitadel admin credentials (auto-generated if not set)
|
||||
ZITADEL_ADMIN_USERNAME=
|
||||
ZITADEL_ADMIN_PASSWORD=
|
||||
# Zitadel external configuration
|
||||
ZITADEL_EXTERNALSECURE=${ZITADEL_EXTERNALSECURE:-true}
|
||||
ZITADEL_EXTERNALPORT=${ZITADEL_EXTERNALPORT:-443}
|
||||
ZITADEL_TLS_MODE=${ZITADEL_TLS_MODE:-external}
|
||||
# Zitadel PAT expiration (1 year from startup)
|
||||
ZITADEL_PAT_EXPIRATION=
|
||||
# Zitadel management endpoint
|
||||
ZITADEL_MANAGEMENT_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN/management/v1
|
||||
# HTTP protocol (http or https)
|
||||
NETBIRD_HTTP_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
|
||||
# Caddy configuration
|
||||
NETBIRD_CADDY_PORT=${NETBIRD_CADDY_PORT:-80}
|
||||
CADDY_SECURE_DOMAIN=
|
||||
|
||||
# Zitadel OIDC endpoints
|
||||
NETBIRD_AUTH_AUTHORITY=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
|
||||
NETBIRD_AUTH_TOKEN_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token
|
||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration
|
||||
NETBIRD_AUTH_JWT_CERTS=${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json
|
||||
NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize
|
||||
NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization
|
||||
NETBIRD_AUTH_USER_ID_CLAIM=${NETBIRD_AUTH_USER_ID_CLAIM:-sub}
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-"openid profile email offline_access"}
|
||||
|
||||
# Zitadel exports
|
||||
export ZITADEL_TAG
|
||||
export ZITADEL_MASTERKEY
|
||||
export ZITADEL_ADMIN_USERNAME
|
||||
export ZITADEL_ADMIN_PASSWORD
|
||||
export ZITADEL_EXTERNALSECURE
|
||||
export ZITADEL_EXTERNALPORT
|
||||
export ZITADEL_TLS_MODE
|
||||
export ZITADEL_PAT_EXPIRATION
|
||||
export ZITADEL_MANAGEMENT_ENDPOINT
|
||||
export NETBIRD_HTTP_PROTOCOL
|
||||
export NETBIRD_CADDY_PORT
|
||||
export CADDY_SECURE_DOMAIN
|
||||
export NETBIRD_AUTH_AUTHORITY
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT
|
||||
export NETBIRD_AUTH_JWT_CERTS
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT
|
||||
export NETBIRD_AUTH_USER_ID_CLAIM
|
||||
export NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
export NETBIRD_RELAY_INTERNAL_PORT
|
||||
|
||||
@@ -1,261 +1,137 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which curl >/dev/null 2>&1; then
|
||||
echo "This script uses curl fetch OpenID configuration from IDP."
|
||||
echo "Please install curl and re-run the script https://curl.se/"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! which jq >/dev/null 2>&1; then
|
||||
echo "This script uses jq to load OpenID configuration from IDP."
|
||||
echo "Please install jq and re-run the script https://stedolan.github.io/jq/"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
# Check required dependencies
|
||||
for cmd in curl jq envsubst openssl; do
|
||||
if ! which $cmd >/dev/null 2>&1; then
|
||||
echo "This script requires $cmd. Please install it and re-run."
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
# Source configuration
|
||||
source setup.env
|
||||
source base.setup.env
|
||||
|
||||
if ! which envsubst >/dev/null 2>&1; then
|
||||
echo "envsubst is needed to run this script"
|
||||
if [[ $(uname) == "Darwin" ]]; then
|
||||
echo "you can install it with homebrew (https://brew.sh):"
|
||||
echo "brew install gettext"
|
||||
else
|
||||
if which apt-get >/dev/null 2>&1; then
|
||||
echo "you can install it by running"
|
||||
echo "apt-get update && apt-get install gettext-base"
|
||||
else
|
||||
echo "you can install it by installing the package gettext with your package manager"
|
||||
fi
|
||||
fi
|
||||
# Validate required variables
|
||||
if [[ -z "$NETBIRD_DOMAIN" ]]; then
|
||||
echo "NETBIRD_DOMAIN is not set, please update your setup.env file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then
|
||||
echo NETBIRD_DOMAIN is not set, please update your setup.env file
|
||||
echo If you are migrating from old versions, you might need to update your variables prefixes from
|
||||
echo WIRETRUSTEE_.. TO NETBIRD_
|
||||
# Check database configuration if using external database
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" && -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
|
||||
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if PostgreSQL is set as the store engine
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then
|
||||
# Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set
|
||||
if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
|
||||
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
|
||||
echo "Please add the following line to your setup.env file:"
|
||||
echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<PG_HOST> user=<PG_USER> password=<PG_PASSWORD> dbname=<PG_DB_NAME> port=<PG_PORT>"'
|
||||
exit 1
|
||||
fi
|
||||
export NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" && -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
|
||||
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if MySQL is set as the store engine
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" ]]; then
|
||||
# Exit if 'NETBIRD_STORE_ENGINE_MYSQL_DSN' is not set
|
||||
if [[ -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
|
||||
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
|
||||
echo "Please add the following line to your setup.env file:"
|
||||
echo 'NETBIRD_STORE_ENGINE_MYSQL_DSN="<username>:<password>@tcp(127.0.0.1:3306)/<database>"'
|
||||
exit 1
|
||||
fi
|
||||
export NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||
fi
|
||||
|
||||
# local development or tests
|
||||
# Configure for local development vs production
|
||||
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
|
||||
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
|
||||
export NETBIRD_MGMT_API_ENDPOINT=http://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
# if not provided, we generate a turn password
|
||||
if [[ "x-$TURN_PASSWORD" == "x-" ]]; then
|
||||
export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
fi
|
||||
|
||||
TURN_EXTERNAL_IP_CONFIG="#"
|
||||
|
||||
if [[ "x-$NETBIRD_TURN_EXTERNAL_IP" == "x-" ]]; then
|
||||
echo "discovering server's public IP"
|
||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
|
||||
if [[ "x-$IP" != "x-" ]]; then
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||
else
|
||||
echo "unable to discover server's public IP"
|
||||
fi
|
||||
export NETBIRD_MGMT_API_ENDPOINT="http://$NETBIRD_DOMAIN"
|
||||
export NETBIRD_HTTP_PROTOCOL="http"
|
||||
export ZITADEL_EXTERNALSECURE="false"
|
||||
export ZITADEL_EXTERNALPORT="80"
|
||||
export ZITADEL_TLS_MODE="disabled"
|
||||
export NETBIRD_RELAY_PROTO="rel"
|
||||
else
|
||||
echo "${NETBIRD_TURN_EXTERNAL_IP}"| egrep '([0-9]{1,3}\.){3}[0-9]{1,3}$' > /dev/null
|
||||
if [[ $? -eq 0 ]]; then
|
||||
echo "using provided server's public IP"
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
|
||||
else
|
||||
echo "provided NETBIRD_TURN_EXTERNAL_IP $NETBIRD_TURN_EXTERNAL_IP is invalid, please correct it and try again"
|
||||
exit 1
|
||||
fi
|
||||
export NETBIRD_HTTP_PROTOCOL="https"
|
||||
export ZITADEL_EXTERNALSECURE="true"
|
||||
export ZITADEL_EXTERNALPORT="443"
|
||||
export ZITADEL_TLS_MODE="external"
|
||||
export NETBIRD_RELAY_PROTO="rels"
|
||||
export CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:443"
|
||||
fi
|
||||
|
||||
# Auto-generate secrets if not provided
|
||||
[[ -z "$TURN_PASSWORD" ]] && export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
[[ -z "$NETBIRD_RELAY_AUTH_SECRET" ]] && export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
[[ -z "$ZITADEL_MASTERKEY" ]] && export ZITADEL_MASTERKEY=$(openssl rand -base64 32 | head -c 32)
|
||||
|
||||
# Generate Zitadel admin credentials if not provided
|
||||
if [[ -z "$ZITADEL_ADMIN_USERNAME" ]]; then
|
||||
export ZITADEL_ADMIN_USERNAME="admin@${NETBIRD_DOMAIN}"
|
||||
fi
|
||||
if [[ -z "$ZITADEL_ADMIN_PASSWORD" ]]; then
|
||||
export ZITADEL_ADMIN_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')!"
|
||||
fi
|
||||
|
||||
# Set Zitadel PAT expiration (1 year from now)
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
export ZITADEL_PAT_EXPIRATION=$(date -u -v+1y "+%Y-%m-%dT%H:%M:%SZ")
|
||||
else
|
||||
export ZITADEL_PAT_EXPIRATION=$(date -u -d "+1 year" "+%Y-%m-%dT%H:%M:%SZ")
|
||||
fi
|
||||
|
||||
# Discover external IP for TURN
|
||||
TURN_EXTERNAL_IP_CONFIG="#"
|
||||
if [[ -z "$NETBIRD_TURN_EXTERNAL_IP" ]]; then
|
||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip' 2>/dev/null || echo "")
|
||||
[[ -n "$IP" ]] && TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||
elif echo "$NETBIRD_TURN_EXTERNAL_IP" | grep -qE '^([0-9]{1,3}\.){3}[0-9]{1,3}$'; then
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
|
||||
fi
|
||||
export TURN_EXTERNAL_IP_CONFIG
|
||||
|
||||
# if not provided, we generate a relay auth secret
|
||||
if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then
|
||||
export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
fi
|
||||
|
||||
artifacts_path="./artifacts"
|
||||
mkdir -p $artifacts_path
|
||||
# Configure endpoints
|
||||
export NETBIRD_AUTH_AUTHORITY="${NETBIRD_HTTP_PROTOCOL}://${NETBIRD_DOMAIN}"
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token"
|
||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration"
|
||||
export NETBIRD_AUTH_JWT_CERTS="${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json"
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize"
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization"
|
||||
export ZITADEL_MANAGEMENT_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/management/v1"
|
||||
export NETBIRD_RELAY_ENDPOINT="${NETBIRD_RELAY_PROTO}://${NETBIRD_DOMAIN}:${ZITADEL_EXTERNALPORT}"
|
||||
|
||||
# Volume names (with backwards compatibility)
|
||||
MGMT_VOLUMENAME="${VOLUME_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||
SIGNAL_VOLUMENAME="${VOLUME_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||
LETSENCRYPT_VOLUMENAME="${VOLUME_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
|
||||
# if volume with wiretrustee- prefix already exists, use it, else create new with netbird-
|
||||
OLD_PREFIX='wiretrustee-'
|
||||
if docker volume ls | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"; then
|
||||
MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||
fi
|
||||
if docker volume ls | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"; then
|
||||
SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||
fi
|
||||
if docker volume ls | grep -q "${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"; then
|
||||
LETSENCRYPT_VOLUMENAME="${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
|
||||
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}" && MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}" && SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||
export MGMT_VOLUMENAME SIGNAL_VOLUMENAME
|
||||
|
||||
# Preserve existing encryption key
|
||||
if test -f 'management.json'; then
|
||||
encKey=$(jq -r ".DataStoreEncryptionKey" management.json 2>/dev/null || echo "null")
|
||||
[[ "$encKey" != "null" && -n "$encKey" ]] && export NETBIRD_DATASTORE_ENC_KEY="$encKey"
|
||||
fi
|
||||
|
||||
export MGMT_VOLUMENAME
|
||||
export SIGNAL_VOLUMENAME
|
||||
export LETSENCRYPT_VOLUMENAME
|
||||
|
||||
#backwards compatibility after migrating to generic OIDC with Auth0
|
||||
if [[ -z "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" ]]; then
|
||||
|
||||
if [[ -z "${NETBIRD_AUTH0_DOMAIN}" ]]; then
|
||||
# not a backward compatible state
|
||||
echo "NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT property must be set in the setup.env file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "It seems like you provided an old setup.env file."
|
||||
echo "Since the release of v0.8.10, we introduced a new set of properties."
|
||||
echo "The script is backward compatible and will continue automatically."
|
||||
echo "In the future versions it will be deprecated. Please refer to the documentation to learn about the changes http://netbird.io/docs/getting-started/self-hosting"
|
||||
|
||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://${NETBIRD_AUTH0_DOMAIN}/.well-known/openid-configuration"
|
||||
export NETBIRD_USE_AUTH0="true"
|
||||
export NETBIRD_AUTH_AUDIENCE=${NETBIRD_AUTH0_AUDIENCE}
|
||||
export NETBIRD_AUTH_CLIENT_ID=${NETBIRD_AUTH0_CLIENT_ID}
|
||||
fi
|
||||
|
||||
echo "loading OpenID configuration from ${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT} to the openid-configuration.json file"
|
||||
curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o ${artifacts_path}/openid-configuration.json
|
||||
|
||||
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=$(jq -r '.authorization_endpoint' ${artifacts_path}/openid-configuration.json)
|
||||
|
||||
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
||||
# user enabled Device Authorization Grant feature
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
||||
fi
|
||||
|
||||
if [ "$NETBIRD_TOKEN_SOURCE" = "idToken" ]; then
|
||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN=true
|
||||
fi
|
||||
|
||||
# Check if letsencrypt was disabled
|
||||
if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
||||
export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443"
|
||||
export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT"
|
||||
export NETBIRD_RELAY_ENDPOINT="rels://$NETBIRD_DOMAIN:$NETBIRD_RELAY_PORT/relay"
|
||||
|
||||
echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore"
|
||||
echo " and a reverse-proxy with Https needs to be placed in front of netbird!"
|
||||
echo "The following forwards have to be setup:"
|
||||
echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80"
|
||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT"
|
||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT"
|
||||
echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80"
|
||||
echo "- $NETBIRD_RELAY_ENDPOINT/ -http-> relay:33080"
|
||||
echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script."
|
||||
echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!"
|
||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||
echo ""
|
||||
|
||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
fi
|
||||
|
||||
# Check if management identity provider is set
|
||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
||||
EXTRA_CONFIG={}
|
||||
|
||||
# extract extra config from all env prefixed with NETBIRD_IDP_MGMT_EXTRA_
|
||||
for var in ${!NETBIRD_IDP_MGMT_EXTRA_*}; do
|
||||
# convert key snake case to camel case
|
||||
key=$(
|
||||
echo "${var#NETBIRD_IDP_MGMT_EXTRA_}" | awk -F "_" \
|
||||
'{for (i=1; i<=NF; i++) {output=output substr($i,1,1) tolower(substr($i,2))} print output}'
|
||||
)
|
||||
value="${!var}"
|
||||
|
||||
echo "$var"
|
||||
EXTRA_CONFIG=$(jq --arg k "$key" --arg v "$value" '.[$k] = $v' <<<"$EXTRA_CONFIG")
|
||||
done
|
||||
|
||||
export NETBIRD_MGMT_IDP
|
||||
export NETBIRD_IDP_MGMT_CLIENT_ID
|
||||
export NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||
export NETBIRD_IDP_MGMT_EXTRA_CONFIG=$EXTRA_CONFIG
|
||||
else
|
||||
export NETBIRD_IDP_MGMT_EXTRA_CONFIG={}
|
||||
fi
|
||||
|
||||
IFS=',' read -r -a REDIRECT_URL_PORTS <<< "$NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS"
|
||||
REDIRECT_URLS=""
|
||||
for port in "${REDIRECT_URL_PORTS[@]}"; do
|
||||
REDIRECT_URLS+="\"http://localhost:${port}\","
|
||||
# Create artifacts directory and backup existing files
|
||||
artifacts_path="./artifacts"
|
||||
mkdir -p "$artifacts_path"
|
||||
bkp_postfix="$(date +%s)"
|
||||
for file in docker-compose.yml management.json turnserver.conf Caddyfile; do
|
||||
[[ -f "${artifacts_path}/${file}" ]] && cp "${artifacts_path}/${file}" "${artifacts_path}/${file}.bkp.${bkp_postfix}"
|
||||
done
|
||||
|
||||
export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,}
|
||||
# Generate configuration files
|
||||
envsubst < docker-compose.yml.tmpl > "$artifacts_path/docker-compose.yml"
|
||||
envsubst < management.json.tmpl | jq . > "$artifacts_path/management.json"
|
||||
envsubst < turnserver.conf.tmpl > "$artifacts_path/turnserver.conf"
|
||||
envsubst < Caddyfile.tmpl > "$artifacts_path/Caddyfile"
|
||||
|
||||
# Remove audience for providers that do not support it
|
||||
if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
|
||||
export NETBIRD_DASH_AUTH_AUDIENCE=none
|
||||
export NETBIRD_AUTH_PKCE_AUDIENCE=
|
||||
fi
|
||||
|
||||
# Read the encryption key
|
||||
if test -f 'management.json'; then
|
||||
encKey=$(jq -r ".DataStoreEncryptionKey" management.json)
|
||||
if [[ "$encKey" != "null" ]]; then
|
||||
export NETBIRD_DATASTORE_ENC_KEY=$encKey
|
||||
|
||||
fi
|
||||
fi
|
||||
|
||||
env | grep NETBIRD
|
||||
|
||||
bkp_postfix="$(date +%s)"
|
||||
if test -f "${artifacts_path}/docker-compose.yml"; then
|
||||
cp $artifacts_path/docker-compose.yml "${artifacts_path}/docker-compose.yml.bkp.${bkp_postfix}"
|
||||
fi
|
||||
|
||||
if test -f "${artifacts_path}/management.json"; then
|
||||
cp $artifacts_path/management.json "${artifacts_path}/management.json.bkp.${bkp_postfix}"
|
||||
fi
|
||||
|
||||
if test -f "${artifacts_path}/turnserver.conf"; then
|
||||
cp ${artifacts_path}/turnserver.conf "${artifacts_path}/turnserver.conf.bkp.${bkp_postfix}"
|
||||
fi
|
||||
envsubst <docker-compose.yml.tmpl >$artifacts_path/docker-compose.yml
|
||||
envsubst <management.json.tmpl | jq . >$artifacts_path/management.json
|
||||
envsubst <turnserver.conf.tmpl >$artifacts_path/turnserver.conf
|
||||
# Print summary
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " NetBird Configuration Complete"
|
||||
echo "=========================================="
|
||||
echo " Domain: $NETBIRD_DOMAIN"
|
||||
echo " Protocol: $NETBIRD_HTTP_PROTOCOL"
|
||||
echo " Zitadel: $ZITADEL_TAG (SQLite)"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo " ADMIN CREDENTIALS (save these!):"
|
||||
echo " Username: $ZITADEL_ADMIN_USERNAME"
|
||||
echo " Password: $ZITADEL_ADMIN_PASSWORD"
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "To start NetBird:"
|
||||
echo " cd $artifacts_path && docker compose up -d"
|
||||
echo ""
|
||||
|
||||
@@ -7,108 +7,137 @@ x-default: &default
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
# Caddy reverse proxy
|
||||
caddy:
|
||||
<<: *default
|
||||
image: caddy:2
|
||||
networks: [netbird]
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird-caddy-data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile:ro
|
||||
depends_on:
|
||||
zitadel:
|
||||
condition: service_healthy
|
||||
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||
ports:
|
||||
- 80:80
|
||||
- 443:443
|
||||
networks: [netbird]
|
||||
environment:
|
||||
# Endpoints
|
||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
# OIDC
|
||||
- AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE
|
||||
- AUTH_AUDIENCE=$NETBIRD_AUTH_CLIENT_ID
|
||||
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||
- AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET
|
||||
- AUTH_CLIENT_SECRET=
|
||||
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
|
||||
- USE_AUTH0=$NETBIRD_USE_AUTH0
|
||||
- USE_AUTH0=false
|
||||
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
|
||||
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
|
||||
- AUTH_REDIRECT_URI=/nb-auth
|
||||
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
|
||||
# SSL
|
||||
- NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
volumes:
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
||||
- LETSENCRYPT_DOMAIN=none
|
||||
depends_on:
|
||||
zitadel:
|
||||
condition: service_healthy
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
<<: *default
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
depends_on:
|
||||
- dashboard
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
ports:
|
||||
- $NETBIRD_SIGNAL_PORT:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"--log-file", "console"
|
||||
]
|
||||
command: ["--log-file", "console"]
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
<<: *default
|
||||
image: netbirdio/relay:$NETBIRD_RELAY_TAG
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- NB_LOG_LEVEL=info
|
||||
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
|
||||
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
||||
# todo: change to a secure secret
|
||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
ports:
|
||||
- $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
|
||||
- NB_LOG_LEVEL=info
|
||||
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_INTERNAL_PORT
|
||||
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
|
||||
# Management
|
||||
management:
|
||||
<<: *default
|
||||
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
|
||||
depends_on:
|
||||
- dashboard
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- $MGMT_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
- ./management.json:/etc/netbird/management.json
|
||||
ports:
|
||||
- $NETBIRD_MGMT_API_PORT:443 #API port
|
||||
# # command for Let's Encrypt validation without dashboard container
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--port", "443",
|
||||
"--port", "80",
|
||||
"--log-file", "console",
|
||||
"--log-level", "info",
|
||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
]
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN",
|
||||
"--idp-sign-key-refresh-enabled"
|
||||
]
|
||||
environment:
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||
|
||||
depends_on:
|
||||
zitadel:
|
||||
condition: service_healthy
|
||||
|
||||
# Coturn
|
||||
coturn:
|
||||
<<: *default
|
||||
image: coturn/coturn:$COTURN_TAG
|
||||
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
|
||||
# - ./cert.pem:/etc/coturn/certs/cert.pem:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
|
||||
# Zitadel Identity Provider (single container with SQLite)
|
||||
zitadel:
|
||||
<<: *default
|
||||
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
|
||||
networks: [netbird]
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
environment:
|
||||
- ZITADEL_LOG_LEVEL=info
|
||||
- ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
|
||||
- ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
|
||||
- ZITADEL_TLS_ENABLED=false
|
||||
- ZITADEL_EXTERNALPORT=$ZITADEL_EXTERNALPORT
|
||||
- ZITADEL_EXTERNALDOMAIN=$NETBIRD_DOMAIN
|
||||
# SQLite database (no separate PostgreSQL container needed)
|
||||
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
|
||||
# First instance: Admin user
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORDCHANGEREQUIRED=true
|
||||
# First instance: Service account for NetBird management
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=netbird-service
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=NetBird Service Account
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINEKEY_TYPE=1
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZITADEL_PAT_EXPIRATION
|
||||
volumes:
|
||||
- netbird-zitadel-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "/app/zitadel", "ready"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
$SIGNAL_VOLUMENAME:
|
||||
$LETSENCRYPT_VOLUMENAME:
|
||||
netbird-caddy-data:
|
||||
netbird-zitadel-data:
|
||||
|
||||
networks:
|
||||
netbird:
|
||||
|
||||
@@ -45,18 +45,18 @@
|
||||
"Engine": "$NETBIRD_STORE_CONFIG_ENGINE"
|
||||
},
|
||||
"HttpConfig": {
|
||||
"Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT",
|
||||
"Address": "0.0.0.0:80",
|
||||
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
|
||||
"AuthAudience": "$NETBIRD_AUTH_CLIENT_ID",
|
||||
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
|
||||
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
|
||||
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"IdpSignKeyRefreshEnabled": $NETBIRD_MGMT_IDP_SIGNKEY_REFRESH,
|
||||
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
||||
"CertFile": "",
|
||||
"CertKey": "",
|
||||
"IdpSignKeyRefreshEnabled": true,
|
||||
"OIDCConfigEndpoint": "$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
||||
},
|
||||
"IdpManagerConfig": {
|
||||
"ManagerType": "$NETBIRD_MGMT_IDP",
|
||||
"ManagerType": "zitadel",
|
||||
"ClientConfig": {
|
||||
"Issuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
@@ -64,40 +64,28 @@
|
||||
"ClientSecret": "$NETBIRD_IDP_MGMT_CLIENT_SECRET",
|
||||
"GrantType": "client_credentials"
|
||||
},
|
||||
"ExtraConfig": $NETBIRD_IDP_MGMT_EXTRA_CONFIG,
|
||||
"Auth0ClientCredentials": null,
|
||||
"AzureClientCredentials": null,
|
||||
"KeycloakClientCredentials": null,
|
||||
"ZitadelClientCredentials": null
|
||||
},
|
||||
"ExtraConfig": {
|
||||
"ManagementEndpoint": "$ZITADEL_MANAGEMENT_ENDPOINT"
|
||||
}
|
||||
},
|
||||
"DeviceAuthorizationFlow": {
|
||||
"Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER",
|
||||
"Provider": "hosted",
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE",
|
||||
"AuthorizationEndpoint": "",
|
||||
"Domain": "$NETBIRD_AUTH0_DOMAIN",
|
||||
"ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID",
|
||||
"ClientSecret": "",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
||||
"Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE",
|
||||
"UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN,
|
||||
"RedirectURLs": null
|
||||
}
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
||||
"Scope": "openid"
|
||||
}
|
||||
},
|
||||
"PKCEAuthorizationFlow": {
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID",
|
||||
"ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET",
|
||||
"Domain": "",
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
|
||||
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
|
||||
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
|
||||
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
|
||||
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
||||
"Scope": "openid profile email offline_access",
|
||||
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,117 +1,65 @@
|
||||
## example file, you can copy this file to setup.env and update its values
|
||||
##
|
||||
# NetBird Self-Hosted Setup Configuration
|
||||
# Copy this file to setup.env and configure the required values
|
||||
|
||||
# Image tags
|
||||
# you can force specific tags for each component; will be set to latest if empty
|
||||
# -------------------------------------------
|
||||
# Required: Domain Configuration
|
||||
# -------------------------------------------
|
||||
# Your NetBird domain (e.g., netbird.mydomain.com)
|
||||
NETBIRD_DOMAIN=""
|
||||
|
||||
# -------------------------------------------
|
||||
# Optional: Image Tags
|
||||
# -------------------------------------------
|
||||
# Leave empty to use 'latest' for all components
|
||||
NETBIRD_DASHBOARD_TAG=""
|
||||
NETBIRD_SIGNAL_TAG=""
|
||||
NETBIRD_MANAGEMENT_TAG=""
|
||||
COTURN_TAG=""
|
||||
NETBIRD_RELAY_TAG=""
|
||||
|
||||
# Dashboard domain. e.g. app.mydomain.com
|
||||
NETBIRD_DOMAIN=""
|
||||
# Zitadel version (default: v2.64.1)
|
||||
ZITADEL_TAG=""
|
||||
|
||||
# TURN server domain. e.g. turn.mydomain.com
|
||||
# if not specified it will assume NETBIRD_DOMAIN
|
||||
# -------------------------------------------
|
||||
# Optional: TURN Server Configuration
|
||||
# -------------------------------------------
|
||||
# TURN server domain (defaults to NETBIRD_DOMAIN)
|
||||
NETBIRD_TURN_DOMAIN=""
|
||||
|
||||
# TURN server public IP address
|
||||
# required for a connection involving peers in
|
||||
# the same network as the server and external peers
|
||||
# usually matches the IP for the domain set in NETBIRD_TURN_DOMAIN
|
||||
# Required for peers behind NAT to connect
|
||||
NETBIRD_TURN_EXTERNAL_IP=""
|
||||
|
||||
# -------------------------------------------
|
||||
# OIDC
|
||||
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration
|
||||
# Optional: Database Configuration
|
||||
# -------------------------------------------
|
||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
||||
# The default setting is to transmit the audience to the IDP during authorization. However,
|
||||
# if your IDP does not have this capability, you can turn this off by setting it to false.
|
||||
#NETBIRD_DASH_AUTH_USE_AUDIENCE=false
|
||||
NETBIRD_AUTH_AUDIENCE=""
|
||||
# e.g. netbird-client
|
||||
NETBIRD_AUTH_CLIENT_ID=""
|
||||
# indicates the scopes that will be requested to the IDP
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=""
|
||||
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
|
||||
# NETBIRD_AUTH_CLIENT_SECRET=""
|
||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
||||
# indicates whether to use Auth0 or not: true or false
|
||||
NETBIRD_USE_AUTH0="false"
|
||||
# if your IDP provider doesn't support fragmented URIs, configure custom
|
||||
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
|
||||
# NETBIRD_AUTH_REDIRECT_URI="/peers"
|
||||
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
|
||||
# Updates the preference to use id tokens instead of access token on dashboard
|
||||
# Okta and Gitlab IDPs can benefit from this
|
||||
# NETBIRD_TOKEN_SOURCE="idToken"
|
||||
# Store engine: sqlite (default), postgres, or mysql
|
||||
NETBIRD_STORE_CONFIG_ENGINE=""
|
||||
|
||||
# For PostgreSQL:
|
||||
# NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<HOST> user=<USER> password=<PASS> dbname=<DB> port=5432"
|
||||
|
||||
# For MySQL:
|
||||
# NETBIRD_STORE_ENGINE_MYSQL_DSN="<user>:<pass>@tcp(127.0.0.1:3306)/<db>"
|
||||
|
||||
# -------------------------------------------
|
||||
# OIDC Device Authorization Flow
|
||||
# Optional: Extra Settings
|
||||
# -------------------------------------------
|
||||
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
|
||||
# Some IDPs requires different audience, scopes and to use id token for device authorization flow
|
||||
# you can customize here:
|
||||
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=false
|
||||
# -------------------------------------------
|
||||
# OIDC PKCE Authorization Flow
|
||||
# -------------------------------------------
|
||||
# Comma separated port numbers. if already in use, PKCE flow will choose an available port from the list as an alternative
|
||||
# eg. 53000,54000
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS="53000"
|
||||
# -------------------------------------------
|
||||
# IDP Management
|
||||
# -------------------------------------------
|
||||
# eg. zitadel, auth0, azure, keycloak
|
||||
NETBIRD_MGMT_IDP="none"
|
||||
# Some IDPs requires different client id and client secret for management api
|
||||
NETBIRD_IDP_MGMT_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||
NETBIRD_IDP_MGMT_CLIENT_SECRET=""
|
||||
# Required when setting up with Keycloak "https://<YOUR_KEYCLOAK_HOST_AND_PORT>/admin/realms/netbird"
|
||||
# NETBIRD_IDP_MGMT_EXTRA_ADMIN_ENDPOINT=
|
||||
# With some IDPs may be needed enabling automatic refresh of signing keys on expire
|
||||
# NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=false
|
||||
# NETBIRD_IDP_MGMT_EXTRA_ variables. See https://docs.netbird.io/selfhosted/identity-providers for more information about your IDP of choice.
|
||||
# -------------------------------------------
|
||||
# Letsencrypt
|
||||
# -------------------------------------------
|
||||
# Disable letsencrypt
|
||||
# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead
|
||||
NETBIRD_DISABLE_LETSENCRYPT=false
|
||||
# e.g. hello@mydomain.com
|
||||
NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
# -------------------------------------------
|
||||
# Extra settings
|
||||
# -------------------------------------------
|
||||
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
|
||||
# Disable anonymous metrics (default: false)
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
|
||||
# DNS domain for peer resolution (default: netbird.selfhosted)
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
# Disable default all-to-all policy for new accounts
|
||||
|
||||
# Disable default all-to-all policy for new accounts (default: false)
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||
|
||||
# -------------------------------------------
|
||||
# Relay settings
|
||||
# Advanced: Zitadel Client IDs
|
||||
# -------------------------------------------
|
||||
# Relay server domain. e.g. relay.mydomain.com
|
||||
# if not specified it will assume NETBIRD_DOMAIN
|
||||
NETBIRD_RELAY_DOMAIN=""
|
||||
|
||||
# Relay server connection port. If none is supplied
|
||||
# it will default to 33080
|
||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
||||
NETBIRD_RELAY_PORT=""
|
||||
|
||||
# Management API connecting port. If none is supplied
|
||||
# it will default to 33073
|
||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
||||
NETBIRD_MGMT_API_PORT=""
|
||||
|
||||
# Signal service connecting port. If none is supplied
|
||||
# it will default to 10000
|
||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
||||
NETBIRD_SIGNAL_PORT=""
|
||||
# These are auto-generated by Zitadel on first boot
|
||||
# Only set these if migrating from an existing Zitadel setup
|
||||
# NETBIRD_AUTH_CLIENT_ID=""
|
||||
# NETBIRD_AUTH_CLIENT_ID_CLI=""
|
||||
# NETBIRD_IDP_MGMT_CLIENT_ID=""
|
||||
# NETBIRD_IDP_MGMT_CLIENT_SECRET=""
|
||||
|
||||
@@ -587,42 +587,40 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cache
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
userData, err := am.idpManager.GetAllAccounts(ctx)
|
||||
// Get all users from IdP
|
||||
idpUsers, err := am.idpManager.GetAllUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData))
|
||||
log.WithContext(ctx).Infof("%d users received from IdP management", len(idpUsers))
|
||||
|
||||
// If the Identity Provider does not support writing AppMetadata,
|
||||
// in cases like this, we expect it to return all users in an "unset" field.
|
||||
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
|
||||
// update their AppMetadata with the AccountID.
|
||||
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
|
||||
for _, user := range unsetData {
|
||||
accountID, err := am.Store.GetAccountByUser(ctx, user.ID)
|
||||
if err == nil {
|
||||
data := userData[accountID.Id]
|
||||
if data == nil {
|
||||
data = make([]*idp.UserData, 0, 1)
|
||||
}
|
||||
|
||||
user.AppMetadata.WTAccountID = accountID.Id
|
||||
|
||||
userData[accountID.Id] = append(data, user)
|
||||
}
|
||||
}
|
||||
// Create a map for quick lookup of IdP users by ID
|
||||
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
|
||||
for _, user := range idpUsers {
|
||||
idpUserMap[user.ID] = user
|
||||
}
|
||||
|
||||
// Group IdP users by their account ID from NetBird's database
|
||||
// NetBird DB is the source of truth for account membership
|
||||
accountUsers := make(map[string][]*idp.UserData)
|
||||
for _, idpUser := range idpUsers {
|
||||
account, err := am.Store.GetAccountByUser(ctx, idpUser.ID)
|
||||
if err != nil {
|
||||
// User exists in IdP but not in NetBird - skip
|
||||
continue
|
||||
}
|
||||
accountUsers[account.Id] = append(accountUsers[account.Id], idpUser)
|
||||
}
|
||||
delete(userData, idp.UnsetAccountID)
|
||||
|
||||
rcvdUsers := 0
|
||||
for accountID, users := range userData {
|
||||
for accountID, users := range accountUsers {
|
||||
rcvdUsers += len(users)
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheEntryExpiration())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
|
||||
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(accountUsers))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -742,10 +740,6 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
}
|
||||
return "", err
|
||||
@@ -757,35 +751,6 @@ func isNil(i idp.Manager) bool {
|
||||
return i == nil || reflect.ValueOf(i).IsNil()
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
accountID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) (any, []cacheStore.Option, error) {
|
||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||
accountIDString := fmt.Sprintf("%v", accountID)
|
||||
@@ -797,28 +762,32 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||
|
||||
// Get users belonging to this account from NetBird's database
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userData, err := am.idpManager.GetAccount(ctx, accountIDString)
|
||||
// Get all users from IdP (we don't have per-account queries anymore)
|
||||
idpUsers, err := am.idpManager.GetAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString)
|
||||
log.WithContext(ctx).Debugf("%d total users in IdP, matching against account %s", len(idpUsers), accountIDString)
|
||||
|
||||
dataMap := make(map[string]*idp.UserData, len(userData))
|
||||
for _, datum := range userData {
|
||||
dataMap[datum.ID] = datum
|
||||
// Create a map for quick lookup of IdP users by ID
|
||||
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
|
||||
for _, user := range idpUsers {
|
||||
idpUserMap[user.ID] = user
|
||||
}
|
||||
|
||||
// Match account users against IdP users
|
||||
matchedUserData := make([]*idp.UserData, 0)
|
||||
for _, user := range accountUsers {
|
||||
if user.IsServiceUser {
|
||||
continue
|
||||
}
|
||||
datum, ok := dataMap[user.Id]
|
||||
datum, ok := idpUserMap[user.Id]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id)
|
||||
continue
|
||||
@@ -972,7 +941,7 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
|
||||
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count
|
||||
func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
|
||||
userDataMap := make(map[string]*idp.UserData, len(data))
|
||||
for _, datum := range data {
|
||||
@@ -980,15 +949,10 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
|
||||
}
|
||||
|
||||
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
||||
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
||||
// as result of for loop knownUsersCount will have number of users are not presented in the cached
|
||||
knownUsersCount := len(accountUsers)
|
||||
for user, loggedInOnce := range accountUsers {
|
||||
if datum, ok := userDataMap[user]; ok {
|
||||
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
||||
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
|
||||
return false
|
||||
}
|
||||
for user := range accountUsers {
|
||||
if _, ok := userDataMap[user]; ok {
|
||||
knownUsersCount--
|
||||
continue
|
||||
}
|
||||
@@ -1078,12 +1042,6 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1110,11 +1068,6 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
@@ -1139,11 +1092,6 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if newUser.PendingApproval {
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
|
||||
} else {
|
||||
@@ -1155,34 +1103,30 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
// Get user from NetBird's database (source of truth for authorization data)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
|
||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||
// Check if user has a pending invite that needs to be redeemed
|
||||
if user.PendingInvite {
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||
// Our job is to just reload cache.
|
||||
go func() {
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}()
|
||||
|
||||
// Update user to mark invite as redeemed
|
||||
user.PendingInvite = false
|
||||
err = am.Store.SaveUser(ctx, user)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to redeem invite for user %s: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", userID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/connectors"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
|
||||
@@ -134,6 +135,7 @@ func NewAPIHandler(
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
connectors.AddEndpoints(accountManager, router)
|
||||
|
||||
return rootRouter, nil
|
||||
}
|
||||
|
||||
590
management/server/http/handlers/connectors/connectors_handler.go
Normal file
590
management/server/http/handlers/connectors/connectors_handler.go
Normal file
@@ -0,0 +1,590 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// API request/response types
|
||||
|
||||
// ConnectorResponse represents an IdP connector in API responses
|
||||
type ConnectorResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
State string `json:"state"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
Servers []string `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
// OIDCConnectorRequest represents a request to create an OIDC connector
|
||||
type OIDCConnectorRequest struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
|
||||
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
|
||||
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
|
||||
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
|
||||
}
|
||||
|
||||
// LDAPConnectorRequest represents a request to create an LDAP connector
|
||||
type LDAPConnectorRequest struct {
|
||||
Name string `json:"name"`
|
||||
Servers []string `json:"servers"`
|
||||
StartTLS bool `json:"start_tls,omitempty"`
|
||||
BaseDN string `json:"base_dn"`
|
||||
BindDN string `json:"bind_dn"`
|
||||
BindPassword string `json:"bind_password"`
|
||||
UserBase string `json:"user_base,omitempty"`
|
||||
UserObjectClass []string `json:"user_object_class,omitempty"`
|
||||
UserFilters []string `json:"user_filters,omitempty"`
|
||||
Timeout string `json:"timeout,omitempty"`
|
||||
Attributes *LDAPAttributesRequest `json:"attributes,omitempty"`
|
||||
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
|
||||
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
|
||||
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
|
||||
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
|
||||
}
|
||||
|
||||
// LDAPAttributesRequest maps LDAP attributes to user fields
|
||||
type LDAPAttributesRequest struct {
|
||||
IDAttribute string `json:"id_attribute,omitempty"`
|
||||
FirstNameAttribute string `json:"first_name_attribute,omitempty"`
|
||||
LastNameAttribute string `json:"last_name_attribute,omitempty"`
|
||||
DisplayNameAttribute string `json:"display_name_attribute,omitempty"`
|
||||
EmailAttribute string `json:"email_attribute,omitempty"`
|
||||
}
|
||||
|
||||
// SAMLConnectorRequest represents a request to create a SAML connector
|
||||
type SAMLConnectorRequest struct {
|
||||
Name string `json:"name"`
|
||||
MetadataXML string `json:"metadata_xml,omitempty"`
|
||||
MetadataURL string `json:"metadata_url,omitempty"`
|
||||
Binding string `json:"binding,omitempty"`
|
||||
WithSignedRequest bool `json:"with_signed_request,omitempty"`
|
||||
NameIDFormat string `json:"name_id_format,omitempty"`
|
||||
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
|
||||
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
|
||||
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
|
||||
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
|
||||
}
|
||||
|
||||
// handler handles HTTP requests for IdP connectors
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
// AddEndpoints registers the connector endpoints to the router
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &handler{accountManager: accountManager}
|
||||
|
||||
router.HandleFunc("/connectors", h.listConnectors).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/connectors/{connectorId}", h.getConnector).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/connectors/{connectorId}", h.deleteConnector).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/connectors/oidc", h.addOIDCConnector).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/connectors/ldap", h.addLDAPConnector).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/connectors/saml", h.addSAMLConnector).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/connectors/{connectorId}/activate", h.activateConnector).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/connectors/{connectorId}/deactivate", h.deleteConnector).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// getConnectorManager retrieves the connector manager from the IdP manager
|
||||
func (h *handler) getConnectorManager() (idp.ConnectorManager, error) {
|
||||
idpManager := h.accountManager.GetIdpManager()
|
||||
if idpManager == nil {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "IdP manager is not configured")
|
||||
}
|
||||
|
||||
connectorManager, ok := idpManager.(idp.ConnectorManager)
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "IdP manager does not support connector management")
|
||||
}
|
||||
|
||||
return connectorManager, nil
|
||||
}
|
||||
|
||||
// listConnectors returns all configured IdP connectors
|
||||
func (h *handler) listConnectors(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
// Only admins can manage connectors
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
connectors, err := connectorManager.ListConnectors(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to list connectors: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]*ConnectorResponse, 0, len(connectors))
|
||||
for _, c := range connectors {
|
||||
response = append(response, toConnectorResponse(c))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, response)
|
||||
}
|
||||
|
||||
// getConnector returns a specific connector by ID
|
||||
func (h *handler) getConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
connectorID := vars["connectorId"]
|
||||
if connectorID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
connector, err := connectorManager.GetConnector(r.Context(), connectorID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "connector not found: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||
}
|
||||
|
||||
// deleteConnector removes a connector
|
||||
func (h *handler) deleteConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
connectorID := vars["connectorId"]
|
||||
if connectorID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := connectorManager.DeleteConnector(r.Context(), connectorID); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to delete connector: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// addOIDCConnector creates a new OIDC connector
|
||||
func (h *handler) addOIDCConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req OIDCConnectorRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
|
||||
return
|
||||
}
|
||||
if req.Issuer == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "issuer is required"), w)
|
||||
return
|
||||
}
|
||||
if req.ClientID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "client_id is required"), w)
|
||||
return
|
||||
}
|
||||
if req.ClientSecret == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "client_secret is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
config := idp.OIDCConnectorConfig{
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
Scopes: req.Scopes,
|
||||
IsAutoCreation: req.IsAutoCreation,
|
||||
IsAutoUpdate: req.IsAutoUpdate,
|
||||
IsCreationAllowed: req.IsCreationAllowed,
|
||||
IsLinkingAllowed: req.IsLinkingAllowed,
|
||||
}
|
||||
|
||||
connector, err := connectorManager.AddOIDCConnector(r.Context(), config)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add OIDC connector: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||
}
|
||||
|
||||
// addLDAPConnector creates a new LDAP connector
|
||||
func (h *handler) addLDAPConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req LDAPConnectorRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
|
||||
return
|
||||
}
|
||||
if len(req.Servers) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "at least one server is required"), w)
|
||||
return
|
||||
}
|
||||
if req.BaseDN == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "base_dn is required"), w)
|
||||
return
|
||||
}
|
||||
if req.BindDN == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "bind_dn is required"), w)
|
||||
return
|
||||
}
|
||||
if req.BindPassword == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "bind_password is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
config := idp.LDAPConnectorConfig{
|
||||
Name: req.Name,
|
||||
Servers: req.Servers,
|
||||
StartTLS: req.StartTLS,
|
||||
BaseDN: req.BaseDN,
|
||||
BindDN: req.BindDN,
|
||||
BindPassword: req.BindPassword,
|
||||
UserBase: req.UserBase,
|
||||
UserObjectClass: req.UserObjectClass,
|
||||
UserFilters: req.UserFilters,
|
||||
Timeout: req.Timeout,
|
||||
IsAutoCreation: req.IsAutoCreation,
|
||||
IsAutoUpdate: req.IsAutoUpdate,
|
||||
IsCreationAllowed: req.IsCreationAllowed,
|
||||
IsLinkingAllowed: req.IsLinkingAllowed,
|
||||
}
|
||||
|
||||
if req.Attributes != nil {
|
||||
config.Attributes = idp.LDAPAttributes{
|
||||
IDAttribute: req.Attributes.IDAttribute,
|
||||
FirstNameAttribute: req.Attributes.FirstNameAttribute,
|
||||
LastNameAttribute: req.Attributes.LastNameAttribute,
|
||||
DisplayNameAttribute: req.Attributes.DisplayNameAttribute,
|
||||
EmailAttribute: req.Attributes.EmailAttribute,
|
||||
}
|
||||
}
|
||||
|
||||
connector, err := connectorManager.AddLDAPConnector(r.Context(), config)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add LDAP connector: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||
}
|
||||
|
||||
// addSAMLConnector creates a new SAML connector
|
||||
func (h *handler) addSAMLConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req SAMLConnectorRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
|
||||
return
|
||||
}
|
||||
if req.MetadataXML == "" && req.MetadataURL == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "either metadata_xml or metadata_url is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
config := idp.SAMLConnectorConfig{
|
||||
Name: req.Name,
|
||||
MetadataXML: req.MetadataXML,
|
||||
MetadataURL: req.MetadataURL,
|
||||
Binding: req.Binding,
|
||||
WithSignedRequest: req.WithSignedRequest,
|
||||
NameIDFormat: req.NameIDFormat,
|
||||
IsAutoCreation: req.IsAutoCreation,
|
||||
IsAutoUpdate: req.IsAutoUpdate,
|
||||
IsCreationAllowed: req.IsCreationAllowed,
|
||||
IsLinkingAllowed: req.IsLinkingAllowed,
|
||||
}
|
||||
|
||||
connector, err := connectorManager.AddSAMLConnector(r.Context(), config)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add SAML connector: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||
}
|
||||
|
||||
// activateConnector adds the connector to the login policy
|
||||
func (h *handler) activateConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
connectorID := vars["connectorId"]
|
||||
if connectorID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := connectorManager.ActivateConnector(r.Context(), connectorID); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to activate connector: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// deactivateConnector removes the connector from the login policy
|
||||
func (h *handler) deactivateConnector(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
connectorID := vars["connectorId"]
|
||||
if connectorID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
connectorManager, err := h.getConnectorManager()
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := connectorManager.DeactivateConnector(r.Context(), connectorID); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to deactivate connector: %v", err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// toConnectorResponse converts an idp.Connector to a ConnectorResponse
|
||||
func toConnectorResponse(c *idp.Connector) *ConnectorResponse {
|
||||
return &ConnectorResponse{
|
||||
ID: c.ID,
|
||||
Name: c.Name,
|
||||
Type: string(c.Type),
|
||||
State: c.State,
|
||||
Issuer: c.Issuer,
|
||||
Servers: c.Servers,
|
||||
}
|
||||
}
|
||||
@@ -1,959 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Auth0Manager auth0 manager client instance
|
||||
type Auth0Manager struct {
|
||||
authIssuer string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// Auth0ClientConfig auth0 manager client configurations
|
||||
type Auth0ClientConfig struct {
|
||||
Audience string
|
||||
AuthIssuer string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// auth0JWTRequest payload struct to request a JWT Token
|
||||
type auth0JWTRequest struct {
|
||||
Audience string `json:"audience"`
|
||||
AuthIssuer string `json:"auth_issuer"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
GrantType string `json:"grant_type"`
|
||||
}
|
||||
|
||||
// Auth0Credentials auth0 authentication information
|
||||
type Auth0Credentials struct {
|
||||
clientConfig Auth0ClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// createUserRequest is a user create request
|
||||
type createUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
AppMeta AppMetadata `json:"app_metadata"`
|
||||
Connection string `json:"connection"`
|
||||
Password string `json:"password"`
|
||||
VerifyEmail bool `json:"verify_email"`
|
||||
}
|
||||
|
||||
// userExportJobRequest is a user export request struct
|
||||
type userExportJobRequest struct {
|
||||
Format string `json:"format"`
|
||||
Fields []map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
// userExportJobResponse is a user export response struct
|
||||
type userExportJobResponse struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Format string `json:"format"`
|
||||
Limit int `json:"limit"`
|
||||
Connection string `json:"connection"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userExportJobStatusResponse is a user export status response struct
|
||||
type userExportJobStatusResponse struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Format string `json:"format"`
|
||||
Limit int `json:"limit"`
|
||||
Location string `json:"location"`
|
||||
Connection string `json:"connection"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userVerificationJobRequest is a user verification request struct
|
||||
type userVerificationJobRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// auth0Profile represents an Auth0 user profile response
|
||||
type auth0Profile struct {
|
||||
AccountID string `json:"wt_account_id"`
|
||||
PendingInvite bool `json:"wt_pending_invite"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastLogin string `json:"last_login"`
|
||||
}
|
||||
|
||||
// Connections represents a single Auth0 connection
|
||||
// https://auth0.com/docs/api/management/v2/connections/get-connections
|
||||
type Connection struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
IsDomainConnection bool `json:"is_domain_connection"`
|
||||
Realms []string `json:"realms"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
Options ConnectionOptions `json:"options"`
|
||||
}
|
||||
|
||||
type ConnectionOptions struct {
|
||||
DomainAliases []string `json:"domain_aliases"`
|
||||
}
|
||||
|
||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.AuthIssuer == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, AuthIssuer is missing")
|
||||
}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.Audience == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &Auth0Manager{
|
||||
authIssuer: config.AuthIssuer,
|
||||
credentials: credentials,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from Auth0
|
||||
func (c *Auth0Credentials) jwtStillValid() bool {
|
||||
return !c.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(c.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token
|
||||
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
var res *http.Response
|
||||
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
|
||||
|
||||
p, err := c.helper.Marshal(auth0JWTRequest(c.clientConfig))
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
|
||||
req, err := http.NewRequest("POST", reqURL, payload)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
|
||||
|
||||
res, err = c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if c.appMetrics != nil {
|
||||
c.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return res, fmt.Errorf("unable to get token, statusCode %d", res.StatusCode)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = c.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = json.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Auth0 Management API
|
||||
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.appMetrics != nil {
|
||||
c.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// If jwtToken has an expires time and we have enough time to do a request return immediately
|
||||
if c.jwtStillValid() {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
res, err := c.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
jwtToken, err := c.parseRequestJWTResponse(res.Body)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
|
||||
c.jwtToken = jwtToken
|
||||
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
|
||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
q.Set("search_engine", "v3")
|
||||
q.Set("per_page", strconv.Itoa(perPage))
|
||||
q.Set("q", "app_metadata.wt_account_id:"+accountID)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), q, nil
|
||||
}
|
||||
|
||||
func requestByUserIDURL(authIssuer, userID string) string {
|
||||
return authIssuer + "/api/v2/users/" + userID
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile. Calls Auth0 API.
|
||||
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var list []*UserData
|
||||
|
||||
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
||||
// auth0 limitation of 1000 users via this endpoint
|
||||
resultsPerPage := 50
|
||||
for page := 0; page < 20; page++ {
|
||||
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
|
||||
}
|
||||
|
||||
var batch []UserData
|
||||
err = json.Unmarshal(body, &batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for user := range batch {
|
||||
list = append(list, &batch[user])
|
||||
}
|
||||
|
||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
||||
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from auth0 via ID
|
||||
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := requestByUserIDURL(am.authIssuer, userID)
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userData UserData
|
||||
err = json.Unmarshal(body, &userData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get UserData, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return &userData, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
||||
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + "/api/v2/users/" + userID
|
||||
|
||||
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
req, err := http.NewRequest("PATCH", reqURL, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCreateUserRequestPayload(email, name, accountID, invitedByEmail string) (string, error) {
|
||||
invite := true
|
||||
req := &createUserRequest{
|
||||
Email: email,
|
||||
Name: name,
|
||||
AppMeta: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
Connection: "Username-Password-Authentication",
|
||||
Password: GeneratePassword(8, 1, 1, 1),
|
||||
VerifyEmail: true,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func buildUserExportRequest() (string, error) {
|
||||
req := &userExportJobRequest{}
|
||||
fields := make([]map[string]string, 0)
|
||||
|
||||
for _, field := range []string{"created_at", "last_login", "user_id", "email", "name"} {
|
||||
fields = append(fields, map[string]string{"name": field})
|
||||
}
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_account_id",
|
||||
"export_as": "wt_account_id",
|
||||
})
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_pending_invite",
|
||||
"export_as": "wt_pending_invite",
|
||||
})
|
||||
|
||||
req.Format = "json"
|
||||
req.Fields = fields
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createRequest(
|
||||
ctx context.Context, method string, endpoint string, body io.Reader,
|
||||
) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + endpoint
|
||||
|
||||
req, err := http.NewRequest(method, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
payloadString, err := buildUserExportRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = jobResp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if jobResp.StatusCode != 200 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
|
||||
}
|
||||
|
||||
var exportJobResp userExportJobResponse
|
||||
|
||||
body, err := io.ReadAll(jobResp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &exportJobResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if exportJobResp.ID == "" {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
|
||||
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if done {
|
||||
return am.downloadProfileExport(ctx, downloadLink)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
||||
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
|
||||
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
userResp := []*UserData{}
|
||||
|
||||
err = am.helper.Unmarshal(body, &userResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userResp, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
||||
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
|
||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var createResp UserData
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &createResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if createResp.ID == "" {
|
||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
userVerificationReq := userVerificationJobRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
payload, err := am.helper.Marshal(userVerificationReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser from Auth0
|
||||
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
|
||||
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("execute delete request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("close delete request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 204 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnections returns detailed list of all connections filtered by given params.
|
||||
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
|
||||
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
|
||||
var connections []Connection
|
||||
|
||||
q := make(url.Values)
|
||||
q.Set("strategy", strings.Join(strategy, ","))
|
||||
|
||||
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
if err != nil {
|
||||
return connections, err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return connections, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 200 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return connections, fmt.Errorf("unable to get connections, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &connections)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
return connections, err
|
||||
}
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
retry := time.NewTicker(10 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Debugf("Export job status stopped...\n")
|
||||
return false, "", ctx.Err()
|
||||
case <-retry.C:
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
|
||||
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
var status userExportJobStatusResponse
|
||||
err = am.helper.Unmarshal(body, &status)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
|
||||
|
||||
if status.Status != "completed" {
|
||||
continue
|
||||
}
|
||||
|
||||
return true, status.Location, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// downloadProfileExport downloads user profiles from auth0 batch job
|
||||
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(ctx, am.httpClient, location, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader := bytes.NewReader(body)
|
||||
|
||||
gzipReader, err := gzip.NewReader(bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(gzipReader)
|
||||
|
||||
res := make(map[string][]*UserData)
|
||||
for decoder.More() {
|
||||
profile := auth0Profile{}
|
||||
err = decoder.Decode(&profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.AccountID != "" {
|
||||
if _, ok := res[profile.AccountID]; !ok {
|
||||
res[profile.AccountID] = []*UserData{}
|
||||
}
|
||||
res[profile.AccountID] = append(res[profile.AccountID],
|
||||
&UserData{
|
||||
ID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
Email: profile.Email,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: profile.AccountID,
|
||||
WTPendingInvite: &profile.PendingInvite,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
req.Header.Add("authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
@@ -1,473 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
code int
|
||||
resBody string
|
||||
reqBody string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: io.NopCloser(strings.NewReader(c.resBody)),
|
||||
}, c.err
|
||||
}
|
||||
|
||||
type mockJsonParser struct {
|
||||
jsonParser JsonParser
|
||||
marshalErrorString string
|
||||
unmarshalErrorString string
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
|
||||
if m.marshalErrorString != "" {
|
||||
return nil, errors.New(m.marshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Marshal(v)
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
|
||||
if m.unmarshalErrorString != "" {
|
||||
return errors.New(m.unmarshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
type mockAuth0Credentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func newTestJWT(t *testing.T, expInt int) string {
|
||||
t.Helper()
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(time.Duration(expInt) * time.Second).Unix(),
|
||||
})
|
||||
var hmacSampleSecret []byte
|
||||
tokenString, err := token.SignedString(hmacSampleSecret)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func TestAuth0_RequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
res, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = json.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputResBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputResBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_JwtStillValid(t *testing.T) {
|
||||
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_Authenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
// expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 15
|
||||
token := newTestJWT(t, exp)
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAuth0Credentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAuth0Credentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
var creds ManagerCredentials
|
||||
if testCase.managerCreds != nil {
|
||||
creds = testCase.managerCreds
|
||||
} else {
|
||||
creds = &Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
}
|
||||
|
||||
manager := Auth0Manager{
|
||||
httpClient: &jwtReqClient,
|
||||
credentials: creds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth0Manager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig Auth0ClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := Auth0ClientConfig{
|
||||
AuthIssuer: "https://abc-auth0.eu.auth0.com",
|
||||
Audience: "https://abc-auth0.eu.auth0.com/api/v2/",
|
||||
ClientID: "abcdefg",
|
||||
ClientSecret: "supersecret",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Scenario With Config",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "shouldn't return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// AuthentikManager authentik manager client instance.
|
||||
type AuthentikManager struct {
|
||||
apiClient *api.APIClient
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// AuthentikClientConfig authentik manager client configurations.
|
||||
type AuthentikClientConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Username string
|
||||
Password string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// AuthentikCredentials authentik authentication information.
|
||||
type AuthentikCredentials struct {
|
||||
clientConfig AuthentikClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
||||
func NewAuthentikManager(config AuthentikClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.Username == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Username is missing")
|
||||
}
|
||||
|
||||
if config.Password == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Password is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.Issuer == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
// authentik client configuration
|
||||
issuerURL, err := url.Parse(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authentikConfig := api.NewConfiguration()
|
||||
authentikConfig.HTTPClient = httpClient
|
||||
authentikConfig.Host = issuerURL.Host
|
||||
authentikConfig.Scheme = issuerURL.Scheme
|
||||
|
||||
credentials := &AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &AuthentikManager{
|
||||
apiClient: api.NewAPIClient(authentikConfig),
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from authentik.
|
||||
func (ac *AuthentikCredentials) jwtStillValid() bool {
|
||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("username", ac.clientConfig.Username)
|
||||
data.Set("password", ac.clientConfig.Password)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
data.Set("scope", "goauthentik.io/api")
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get authentik token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the authentik management API.
|
||||
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if ac.jwtStillValid() {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
|
||||
ac.jwtToken = jwtToken
|
||||
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from authentik via ID.
|
||||
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, resp, err := am.apiClient.CoreApi.CoreUsersRetrieve(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData := parseAuthentikUser(*user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Authentik account.
|
||||
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
page := int32(1)
|
||||
for {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for _, user := range userList.Results {
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
page = int32(userList.GetPagination().Next)
|
||||
if userList.GetPagination().Next == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
||||
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Email(email).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Authentik
|
||||
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := map[string]api.APIKey{
|
||||
"authentik": {
|
||||
Key: jwtToken.AccessToken,
|
||||
Prefix: jwtToken.TokenType,
|
||||
},
|
||||
}
|
||||
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
|
||||
}
|
||||
|
||||
func parseAuthentikUser(user api.User) *UserData {
|
||||
return &UserData{
|
||||
Email: *user.Email,
|
||||
Name: user.Name,
|
||||
ID: strconv.FormatInt(int64(user.Pk), 10),
|
||||
}
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewAuthentikManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig AuthentikClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := AuthentikClientConfig{
|
||||
ClientID: "client_id",
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
TokenEndpoint: "https://localhost:8080/application/o/token/",
|
||||
Issuer: "https://localhost:8080/application/o/netbird/",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.Username = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing Username Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.Password = ""
|
||||
|
||||
testCase4 := test{
|
||||
name: "Missing Password Configuration",
|
||||
inputConfig: testCase4Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Missing GrantType Configuration",
|
||||
inputConfig: testCase5Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase6Config := defaultTestConfig
|
||||
testCase6Config.Issuer = ""
|
||||
|
||||
testCase6 := test{
|
||||
name: "Missing Issuer Configuration",
|
||||
inputConfig: testCase6Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikRequestJWTToken(t *testing.T) {
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,459 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const profileFields = "id,displayName,mail,userPrincipalName"
|
||||
|
||||
// AzureManager azure manager client instance.
|
||||
type AzureManager struct {
|
||||
ClientID string
|
||||
ObjectID string
|
||||
GraphAPIEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// AzureClientConfig azure manager client configurations.
|
||||
type AzureClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ObjectID string
|
||||
GraphAPIEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// AzureCredentials azure authentication information.
|
||||
type AzureCredentials struct {
|
||||
clientConfig AzureClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// azureProfile represents an azure user profile.
|
||||
type azureProfile map[string]any
|
||||
|
||||
// NewAzureManager creates a new instance of the AzureManager.
|
||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GraphAPIEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.ObjectID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &AzureCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &AzureManager{
|
||||
ObjectID: config.ObjectID,
|
||||
ClientID: config.ClientID,
|
||||
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
||||
func (ac *AzureCredentials) jwtStillValid() bool {
|
||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get base url and add "/.default" as scope
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
scopeURL := baseURL + "/.default"
|
||||
data.Set("scope", scopeURL)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the azure Management API.
|
||||
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if ac.jwtStillValid() {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
|
||||
ac.jwtToken = jwtToken
|
||||
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in azure AD Idp.
|
||||
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get(ctx, "users/"+userID, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get(ctx, "users/"+email, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, profile.userData())
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure.
|
||||
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debugf("delete idp user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Azure AD account.
|
||||
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
q.Add("$top", "500")
|
||||
|
||||
for nextLink := "users"; nextLink != ""; {
|
||||
body, err := am.get(ctx, nextLink, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles struct {
|
||||
Value []azureProfile
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
err = am.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, profile := range profiles.Value {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
nextLink = profiles.NextLink
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reqURL string
|
||||
if strings.HasPrefix(resource, "https") {
|
||||
// Already an absolute URL for paging
|
||||
reqURL = resource
|
||||
} else {
|
||||
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (ap azureProfile) userData() *UserData {
|
||||
id, ok := ap["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
|
||||
email, ok := ap["userPrincipalName"].(string)
|
||||
if !ok {
|
||||
email = ""
|
||||
}
|
||||
|
||||
name, ok := ap["displayName"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: id,
|
||||
}
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAzureJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := AzureClientConfig{}
|
||||
|
||||
creds := AzureCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AzureClientConfig{}
|
||||
|
||||
creds := AzureCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile azureProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "Good Request",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test1",
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test1@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test1@test.com",
|
||||
Name: "John Doe",
|
||||
ID: "test1",
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Missing User ID",
|
||||
invite: true,
|
||||
inputProfile: azureProfile{
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test2@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test2@test.com",
|
||||
Name: "John Doe",
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase3 := azureProfileTest{
|
||||
name: "Missing User Name",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test3",
|
||||
"userPrincipalName": "test3@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test3",
|
||||
Email: "test3@test.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
457
management/server/idp/connector.go
Normal file
457
management/server/idp/connector.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ConnectorType represents the type of external identity provider connector
|
||||
type ConnectorType string
|
||||
|
||||
const (
|
||||
ConnectorTypeOIDC ConnectorType = "oidc"
|
||||
ConnectorTypeLDAP ConnectorType = "ldap"
|
||||
ConnectorTypeSAML ConnectorType = "saml"
|
||||
)
|
||||
|
||||
// Connector represents an external identity provider configured in Zitadel
|
||||
type Connector struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type ConnectorType `json:"type"`
|
||||
State string `json:"state"`
|
||||
Issuer string `json:"issuer,omitempty"` // for OIDC
|
||||
Servers []string `json:"servers,omitempty"` // for LDAP
|
||||
}
|
||||
|
||||
// OIDCConnectorConfig contains configuration for adding an OIDC connector
|
||||
type OIDCConnectorConfig struct {
|
||||
Name string `json:"name"`
|
||||
Issuer string `json:"issuer"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
IsIDTokenMapping bool `json:"isIdTokenMapping,omitempty"`
|
||||
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
|
||||
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
|
||||
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
|
||||
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
|
||||
IsAutoAccountLinking bool `json:"isAutoAccountLinking,omitempty"`
|
||||
AccountLinkingEnabled bool `json:"accountLinkingEnabled,omitempty"`
|
||||
}
|
||||
|
||||
// LDAPConnectorConfig contains configuration for adding an LDAP connector
|
||||
type LDAPConnectorConfig struct {
|
||||
Name string `json:"name"`
|
||||
Servers []string `json:"servers"` // e.g., ["ldap://localhost:389"]
|
||||
StartTLS bool `json:"startTls,omitempty"`
|
||||
BaseDN string `json:"baseDn"`
|
||||
BindDN string `json:"bindDn"`
|
||||
BindPassword string `json:"bindPassword"`
|
||||
UserBase string `json:"userBase,omitempty"` // typically "dn"
|
||||
UserObjectClass []string `json:"userObjectClass,omitempty"` // e.g., ["user", "person"]
|
||||
UserFilters []string `json:"userFilters,omitempty"` // e.g., ["uid", "email"]
|
||||
Timeout string `json:"timeout,omitempty"` // e.g., "10s"
|
||||
Attributes LDAPAttributes `json:"attributes,omitempty"`
|
||||
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
|
||||
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
|
||||
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
|
||||
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
|
||||
}
|
||||
|
||||
// LDAPAttributes maps LDAP attributes to Zitadel user fields
|
||||
type LDAPAttributes struct {
|
||||
IDAttribute string `json:"idAttribute,omitempty"`
|
||||
FirstNameAttribute string `json:"firstNameAttribute,omitempty"`
|
||||
LastNameAttribute string `json:"lastNameAttribute,omitempty"`
|
||||
DisplayNameAttribute string `json:"displayNameAttribute,omitempty"`
|
||||
NickNameAttribute string `json:"nickNameAttribute,omitempty"`
|
||||
EmailAttribute string `json:"emailAttribute,omitempty"`
|
||||
EmailVerified string `json:"emailVerified,omitempty"`
|
||||
PhoneAttribute string `json:"phoneAttribute,omitempty"`
|
||||
PhoneVerified string `json:"phoneVerified,omitempty"`
|
||||
AvatarURLAttribute string `json:"avatarUrlAttribute,omitempty"`
|
||||
ProfileAttribute string `json:"profileAttribute,omitempty"`
|
||||
}
|
||||
|
||||
// SAMLConnectorConfig contains configuration for adding a SAML connector
|
||||
type SAMLConnectorConfig struct {
|
||||
Name string `json:"name"`
|
||||
MetadataXML string `json:"metadataXml,omitempty"`
|
||||
MetadataURL string `json:"metadataUrl,omitempty"`
|
||||
Binding string `json:"binding,omitempty"` // "SAML_BINDING_POST" or "SAML_BINDING_REDIRECT"
|
||||
WithSignedRequest bool `json:"withSignedRequest,omitempty"`
|
||||
NameIDFormat string `json:"nameIdFormat,omitempty"`
|
||||
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
|
||||
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
|
||||
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
|
||||
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
|
||||
}
|
||||
|
||||
// ConnectorManager defines the interface for managing external IdP connectors
|
||||
type ConnectorManager interface {
|
||||
// AddOIDCConnector adds a Generic OIDC identity provider connector
|
||||
AddOIDCConnector(ctx context.Context, config OIDCConnectorConfig) (*Connector, error)
|
||||
// AddLDAPConnector adds an LDAP identity provider connector
|
||||
AddLDAPConnector(ctx context.Context, config LDAPConnectorConfig) (*Connector, error)
|
||||
// AddSAMLConnector adds a SAML identity provider connector
|
||||
AddSAMLConnector(ctx context.Context, config SAMLConnectorConfig) (*Connector, error)
|
||||
// ListConnectors returns all configured identity provider connectors
|
||||
ListConnectors(ctx context.Context) ([]*Connector, error)
|
||||
// GetConnector returns a specific connector by ID
|
||||
GetConnector(ctx context.Context, connectorID string) (*Connector, error)
|
||||
// DeleteConnector removes an identity provider connector
|
||||
DeleteConnector(ctx context.Context, connectorID string) error
|
||||
// ActivateConnector adds the connector to the login policy
|
||||
ActivateConnector(ctx context.Context, connectorID string) error
|
||||
// DeactivateConnector removes the connector from the login policy
|
||||
DeactivateConnector(ctx context.Context, connectorID string) error
|
||||
}
|
||||
|
||||
// zitadelProviderResponse represents the response from creating a provider
|
||||
type zitadelProviderResponse struct {
|
||||
ID string `json:"id"`
|
||||
Details struct {
|
||||
Sequence string `json:"sequence"`
|
||||
CreationDate string `json:"creationDate"`
|
||||
ChangeDate string `json:"changeDate"`
|
||||
ResourceOwner string `json:"resourceOwner"`
|
||||
} `json:"details"`
|
||||
}
|
||||
|
||||
// zitadelProviderTemplate represents a provider in the list response
|
||||
type zitadelProviderTemplate struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
State string `json:"state"` // IDP_STATE_ACTIVE, IDP_STATE_INACTIVE
|
||||
Type string `json:"type"` // IDP_TYPE_OIDC, IDP_TYPE_LDAP, IDP_TYPE_SAML, etc.
|
||||
Owner string `json:"owner"` // IDP_OWNER_TYPE_ORG, IDP_OWNER_TYPE_SYSTEM
|
||||
// Type-specific fields
|
||||
OIDC *struct {
|
||||
Issuer string `json:"issuer"`
|
||||
ClientID string `json:"clientId"`
|
||||
} `json:"oidc,omitempty"`
|
||||
LDAP *struct {
|
||||
Servers []string `json:"servers"`
|
||||
BaseDN string `json:"baseDn"`
|
||||
} `json:"ldap,omitempty"`
|
||||
SAML *struct {
|
||||
MetadataURL string `json:"metadataUrl"`
|
||||
} `json:"saml,omitempty"`
|
||||
}
|
||||
|
||||
// AddOIDCConnector adds a Generic OIDC identity provider connector to Zitadel
|
||||
func (zm *ZitadelManager) AddOIDCConnector(ctx context.Context, config OIDCConnectorConfig) (*Connector, error) {
|
||||
// Set defaults for creation/linking if not specified
|
||||
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
|
||||
config.IsCreationAllowed = true
|
||||
config.IsLinkingAllowed = true
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"name": config.Name,
|
||||
"issuer": config.Issuer,
|
||||
"clientId": config.ClientID,
|
||||
"clientSecret": config.ClientSecret,
|
||||
"isIdTokenMapping": config.IsIDTokenMapping,
|
||||
"isAutoCreation": config.IsAutoCreation,
|
||||
"isAutoUpdate": config.IsAutoUpdate,
|
||||
"isCreationAllowed": config.IsCreationAllowed,
|
||||
"isLinkingAllowed": config.IsLinkingAllowed,
|
||||
}
|
||||
|
||||
if len(config.Scopes) > 0 {
|
||||
payload["scopes"] = config.Scopes
|
||||
}
|
||||
|
||||
body, err := zm.helper.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal OIDC connector config: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := zm.post(ctx, "idps/generic_oidc", string(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add OIDC connector: %w", err)
|
||||
}
|
||||
|
||||
var resp zitadelProviderResponse
|
||||
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal OIDC connector response: %w", err)
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
ID: resp.ID,
|
||||
Name: config.Name,
|
||||
Type: ConnectorTypeOIDC,
|
||||
State: "active",
|
||||
Issuer: config.Issuer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddLDAPConnector adds an LDAP identity provider connector to Zitadel
|
||||
func (zm *ZitadelManager) AddLDAPConnector(ctx context.Context, config LDAPConnectorConfig) (*Connector, error) {
|
||||
// Set defaults
|
||||
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
|
||||
config.IsCreationAllowed = true
|
||||
config.IsLinkingAllowed = true
|
||||
}
|
||||
if config.UserBase == "" {
|
||||
config.UserBase = "dn"
|
||||
}
|
||||
if config.Timeout == "" {
|
||||
config.Timeout = "10s"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"name": config.Name,
|
||||
"servers": config.Servers,
|
||||
"startTls": config.StartTLS,
|
||||
"baseDn": config.BaseDN,
|
||||
"bindDn": config.BindDN,
|
||||
"bindPassword": config.BindPassword,
|
||||
"userBase": config.UserBase,
|
||||
"timeout": config.Timeout,
|
||||
"isAutoCreation": config.IsAutoCreation,
|
||||
"isAutoUpdate": config.IsAutoUpdate,
|
||||
"isCreationAllowed": config.IsCreationAllowed,
|
||||
"isLinkingAllowed": config.IsLinkingAllowed,
|
||||
}
|
||||
|
||||
if len(config.UserObjectClass) > 0 {
|
||||
payload["userObjectClasses"] = config.UserObjectClass
|
||||
}
|
||||
if len(config.UserFilters) > 0 {
|
||||
payload["userFilters"] = config.UserFilters
|
||||
}
|
||||
|
||||
// Add attribute mappings if provided
|
||||
attrs := make(map[string]string)
|
||||
if config.Attributes.IDAttribute != "" {
|
||||
attrs["idAttribute"] = config.Attributes.IDAttribute
|
||||
}
|
||||
if config.Attributes.FirstNameAttribute != "" {
|
||||
attrs["firstNameAttribute"] = config.Attributes.FirstNameAttribute
|
||||
}
|
||||
if config.Attributes.LastNameAttribute != "" {
|
||||
attrs["lastNameAttribute"] = config.Attributes.LastNameAttribute
|
||||
}
|
||||
if config.Attributes.DisplayNameAttribute != "" {
|
||||
attrs["displayNameAttribute"] = config.Attributes.DisplayNameAttribute
|
||||
}
|
||||
if config.Attributes.EmailAttribute != "" {
|
||||
attrs["emailAttribute"] = config.Attributes.EmailAttribute
|
||||
}
|
||||
if len(attrs) > 0 {
|
||||
payload["attributes"] = attrs
|
||||
}
|
||||
|
||||
body, err := zm.helper.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal LDAP connector config: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := zm.post(ctx, "idps/ldap", string(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add LDAP connector: %w", err)
|
||||
}
|
||||
|
||||
var resp zitadelProviderResponse
|
||||
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal LDAP connector response: %w", err)
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
ID: resp.ID,
|
||||
Name: config.Name,
|
||||
Type: ConnectorTypeLDAP,
|
||||
State: "active",
|
||||
Servers: config.Servers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddSAMLConnector adds a SAML identity provider connector to Zitadel
|
||||
func (zm *ZitadelManager) AddSAMLConnector(ctx context.Context, config SAMLConnectorConfig) (*Connector, error) {
|
||||
// Set defaults
|
||||
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
|
||||
config.IsCreationAllowed = true
|
||||
config.IsLinkingAllowed = true
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"name": config.Name,
|
||||
"isAutoCreation": config.IsAutoCreation,
|
||||
"isAutoUpdate": config.IsAutoUpdate,
|
||||
"isCreationAllowed": config.IsCreationAllowed,
|
||||
"isLinkingAllowed": config.IsLinkingAllowed,
|
||||
}
|
||||
|
||||
if config.MetadataXML != "" {
|
||||
payload["metadataXml"] = config.MetadataXML
|
||||
} else if config.MetadataURL != "" {
|
||||
payload["metadataUrl"] = config.MetadataURL
|
||||
} else {
|
||||
return nil, fmt.Errorf("either metadataXml or metadataUrl must be provided")
|
||||
}
|
||||
|
||||
if config.Binding != "" {
|
||||
payload["binding"] = config.Binding
|
||||
}
|
||||
if config.WithSignedRequest {
|
||||
payload["withSignedRequest"] = config.WithSignedRequest
|
||||
}
|
||||
if config.NameIDFormat != "" {
|
||||
payload["nameIdFormat"] = config.NameIDFormat
|
||||
}
|
||||
|
||||
body, err := zm.helper.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal SAML connector config: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := zm.post(ctx, "idps/saml", string(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add SAML connector: %w", err)
|
||||
}
|
||||
|
||||
var resp zitadelProviderResponse
|
||||
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal SAML connector response: %w", err)
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
ID: resp.ID,
|
||||
Name: config.Name,
|
||||
Type: ConnectorTypeSAML,
|
||||
State: "active",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListConnectors returns all configured identity provider connectors
|
||||
func (zm *ZitadelManager) ListConnectors(ctx context.Context) ([]*Connector, error) {
|
||||
// Use the search endpoint to list all providers
|
||||
respBody, err := zm.post(ctx, "idps/_search", "{}")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list connectors: %w", err)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Result []zitadelProviderTemplate `json:"result"`
|
||||
}
|
||||
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal connectors response: %w", err)
|
||||
}
|
||||
|
||||
connectors := make([]*Connector, 0, len(resp.Result))
|
||||
for _, p := range resp.Result {
|
||||
connector := &Connector{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
State: normalizeState(p.State),
|
||||
Type: normalizeType(p.Type),
|
||||
}
|
||||
|
||||
// Add type-specific fields
|
||||
if p.OIDC != nil {
|
||||
connector.Issuer = p.OIDC.Issuer
|
||||
}
|
||||
if p.LDAP != nil {
|
||||
connector.Servers = p.LDAP.Servers
|
||||
}
|
||||
|
||||
connectors = append(connectors, connector)
|
||||
}
|
||||
|
||||
return connectors, nil
|
||||
}
|
||||
|
||||
// GetConnector returns a specific connector by ID
|
||||
func (zm *ZitadelManager) GetConnector(ctx context.Context, connectorID string) (*Connector, error) {
|
||||
respBody, err := zm.get(ctx, fmt.Sprintf("idps/%s", connectorID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get connector: %w", err)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
IDP zitadelProviderTemplate `json:"idp"`
|
||||
}
|
||||
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal connector response: %w", err)
|
||||
}
|
||||
|
||||
connector := &Connector{
|
||||
ID: resp.IDP.ID,
|
||||
Name: resp.IDP.Name,
|
||||
State: normalizeState(resp.IDP.State),
|
||||
Type: normalizeType(resp.IDP.Type),
|
||||
}
|
||||
|
||||
if resp.IDP.OIDC != nil {
|
||||
connector.Issuer = resp.IDP.OIDC.Issuer
|
||||
}
|
||||
if resp.IDP.LDAP != nil {
|
||||
connector.Servers = resp.IDP.LDAP.Servers
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
// DeleteConnector removes an identity provider connector
|
||||
func (zm *ZitadelManager) DeleteConnector(ctx context.Context, connectorID string) error {
|
||||
if err := zm.delete(ctx, fmt.Sprintf("idps/%s", connectorID)); err != nil {
|
||||
return fmt.Errorf("delete connector: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActivateConnector adds the connector to the organization's login policy
|
||||
func (zm *ZitadelManager) ActivateConnector(ctx context.Context, connectorID string) error {
|
||||
payload := map[string]string{
|
||||
"idpId": connectorID,
|
||||
}
|
||||
|
||||
body, err := zm.helper.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal activate request: %w", err)
|
||||
}
|
||||
|
||||
_, err = zm.post(ctx, "policies/login/idps", string(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("activate connector: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeactivateConnector removes the connector from the organization's login policy
|
||||
func (zm *ZitadelManager) DeactivateConnector(ctx context.Context, connectorID string) error {
|
||||
if err := zm.delete(ctx, fmt.Sprintf("policies/login/idps/%s", connectorID)); err != nil {
|
||||
return fmt.Errorf("deactivate connector: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeState converts Zitadel state to a simple string
|
||||
func normalizeState(state string) string {
|
||||
switch state {
|
||||
case "IDP_STATE_ACTIVE":
|
||||
return "active"
|
||||
case "IDP_STATE_INACTIVE":
|
||||
return "inactive"
|
||||
default:
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeType converts Zitadel type to ConnectorType
|
||||
func normalizeType(idpType string) ConnectorType {
|
||||
switch idpType {
|
||||
case "IDP_TYPE_OIDC", "IDP_TYPE_OIDC_GENERIC":
|
||||
return ConnectorTypeOIDC
|
||||
case "IDP_TYPE_LDAP":
|
||||
return ConnectorTypeLDAP
|
||||
case "IDP_TYPE_SAML":
|
||||
return ConnectorTypeSAML
|
||||
default:
|
||||
return ConnectorType(idpType)
|
||||
}
|
||||
}
|
||||
313
management/server/idp/connector_test.go
Normal file
313
management/server/idp/connector_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestZitadelManager_AddOIDCConnector(t *testing.T) {
|
||||
// Create a mock response for the OIDC connector creation
|
||||
mockResponse := `{"id": "oidc-123", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
|
||||
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusCreated,
|
||||
resBody: mockResponse,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
config := OIDCConnectorConfig{
|
||||
Name: "Okta",
|
||||
Issuer: "https://okta.example.com",
|
||||
ClientID: "client-123",
|
||||
ClientSecret: "secret-456",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
}
|
||||
|
||||
connector, err := manager.AddOIDCConnector(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc-123", connector.ID)
|
||||
assert.Equal(t, "Okta", connector.Name)
|
||||
assert.Equal(t, ConnectorTypeOIDC, connector.Type)
|
||||
assert.Equal(t, "https://okta.example.com", connector.Issuer)
|
||||
|
||||
// Verify the request body contains expected fields
|
||||
var reqBody map[string]any
|
||||
err = json.Unmarshal([]byte(mockClient.reqBody), &reqBody)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Okta", reqBody["name"])
|
||||
assert.Equal(t, "https://okta.example.com", reqBody["issuer"])
|
||||
assert.Equal(t, "client-123", reqBody["clientId"])
|
||||
}
|
||||
|
||||
func TestZitadelManager_AddLDAPConnector(t *testing.T) {
|
||||
mockResponse := `{"id": "ldap-456", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
|
||||
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusCreated,
|
||||
resBody: mockResponse,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
config := LDAPConnectorConfig{
|
||||
Name: "Corporate LDAP",
|
||||
Servers: []string{"ldap://ldap.example.com:389"},
|
||||
BaseDN: "dc=example,dc=com",
|
||||
BindDN: "cn=admin,dc=example,dc=com",
|
||||
BindPassword: "admin-password",
|
||||
Attributes: LDAPAttributes{
|
||||
IDAttribute: "uid",
|
||||
EmailAttribute: "mail",
|
||||
},
|
||||
}
|
||||
|
||||
connector, err := manager.AddLDAPConnector(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ldap-456", connector.ID)
|
||||
assert.Equal(t, "Corporate LDAP", connector.Name)
|
||||
assert.Equal(t, ConnectorTypeLDAP, connector.Type)
|
||||
assert.Equal(t, []string{"ldap://ldap.example.com:389"}, connector.Servers)
|
||||
}
|
||||
|
||||
func TestZitadelManager_AddSAMLConnector(t *testing.T) {
|
||||
mockResponse := `{"id": "saml-789", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
|
||||
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusCreated,
|
||||
resBody: mockResponse,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
config := SAMLConnectorConfig{
|
||||
Name: "Enterprise SAML",
|
||||
MetadataURL: "https://idp.example.com/metadata.xml",
|
||||
}
|
||||
|
||||
connector, err := manager.AddSAMLConnector(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "saml-789", connector.ID)
|
||||
assert.Equal(t, "Enterprise SAML", connector.Name)
|
||||
assert.Equal(t, ConnectorTypeSAML, connector.Type)
|
||||
}
|
||||
|
||||
func TestZitadelManager_AddSAMLConnector_RequiresMetadata(t *testing.T) {
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: &mockHTTPClient{},
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
config := SAMLConnectorConfig{
|
||||
Name: "Invalid SAML",
|
||||
// Neither MetadataXML nor MetadataURL provided
|
||||
}
|
||||
|
||||
_, err := manager.AddSAMLConnector(context.Background(), config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "metadataXml or metadataUrl must be provided")
|
||||
}
|
||||
|
||||
func TestZitadelManager_ListConnectors(t *testing.T) {
|
||||
mockResponse := `{
|
||||
"result": [
|
||||
{
|
||||
"id": "oidc-1",
|
||||
"name": "Google",
|
||||
"state": "IDP_STATE_ACTIVE",
|
||||
"type": "IDP_TYPE_OIDC",
|
||||
"oidc": {"issuer": "https://accounts.google.com", "clientId": "google-client"}
|
||||
},
|
||||
{
|
||||
"id": "ldap-1",
|
||||
"name": "AD",
|
||||
"state": "IDP_STATE_INACTIVE",
|
||||
"type": "IDP_TYPE_LDAP",
|
||||
"ldap": {"servers": ["ldap://ad.example.com:389"], "baseDn": "dc=example,dc=com"}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusOK,
|
||||
resBody: mockResponse,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
connectors, err := manager.ListConnectors(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, connectors, 2)
|
||||
|
||||
assert.Equal(t, "oidc-1", connectors[0].ID)
|
||||
assert.Equal(t, "Google", connectors[0].Name)
|
||||
assert.Equal(t, "active", connectors[0].State)
|
||||
assert.Equal(t, ConnectorTypeOIDC, connectors[0].Type)
|
||||
assert.Equal(t, "https://accounts.google.com", connectors[0].Issuer)
|
||||
|
||||
assert.Equal(t, "ldap-1", connectors[1].ID)
|
||||
assert.Equal(t, "AD", connectors[1].Name)
|
||||
assert.Equal(t, "inactive", connectors[1].State)
|
||||
assert.Equal(t, ConnectorTypeLDAP, connectors[1].Type)
|
||||
assert.Equal(t, []string{"ldap://ad.example.com:389"}, connectors[1].Servers)
|
||||
}
|
||||
|
||||
func TestZitadelManager_GetConnector(t *testing.T) {
|
||||
mockResponse := `{
|
||||
"idp": {
|
||||
"id": "oidc-123",
|
||||
"name": "Okta",
|
||||
"state": "IDP_STATE_ACTIVE",
|
||||
"type": "IDP_TYPE_OIDC",
|
||||
"oidc": {"issuer": "https://okta.example.com", "clientId": "client-123"}
|
||||
}
|
||||
}`
|
||||
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusOK,
|
||||
resBody: mockResponse,
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
connector, err := manager.GetConnector(context.Background(), "oidc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc-123", connector.ID)
|
||||
assert.Equal(t, "Okta", connector.Name)
|
||||
assert.Equal(t, ConnectorTypeOIDC, connector.Type)
|
||||
assert.Equal(t, "https://okta.example.com", connector.Issuer)
|
||||
}
|
||||
|
||||
func TestZitadelManager_DeleteConnector(t *testing.T) {
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusOK,
|
||||
resBody: "{}",
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
err := manager.DeleteConnector(context.Background(), "oidc-123")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestZitadelManager_ActivateConnector(t *testing.T) {
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusOK,
|
||||
resBody: "{}",
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
err := manager.ActivateConnector(context.Background(), "oidc-123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the request body
|
||||
var reqBody map[string]string
|
||||
err = json.Unmarshal([]byte(mockClient.reqBody), &reqBody)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "oidc-123", reqBody["idpId"])
|
||||
}
|
||||
|
||||
func TestZitadelManager_DeactivateConnector(t *testing.T) {
|
||||
mockClient := &mockHTTPClient{
|
||||
code: http.StatusOK,
|
||||
resBody: "{}",
|
||||
}
|
||||
|
||||
manager := &ZitadelManager{
|
||||
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||
httpClient: mockClient,
|
||||
credentials: &mockCredentials{token: "test-token"},
|
||||
helper: JsonParser{},
|
||||
}
|
||||
|
||||
err := manager.DeactivateConnector(context.Background(), "oidc-123")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNormalizeState(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"IDP_STATE_ACTIVE", "active"},
|
||||
{"IDP_STATE_INACTIVE", "inactive"},
|
||||
{"custom", "custom"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, normalizeState(tc.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeType(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected ConnectorType
|
||||
}{
|
||||
{"IDP_TYPE_OIDC", ConnectorTypeOIDC},
|
||||
{"IDP_TYPE_OIDC_GENERIC", ConnectorTypeOIDC},
|
||||
{"IDP_TYPE_LDAP", ConnectorTypeLDAP},
|
||||
{"IDP_TYPE_SAML", ConnectorTypeSAML},
|
||||
{"CUSTOM", ConnectorType("CUSTOM")},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, normalizeType(tc.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockCredentials is a mock implementation of ManagerCredentials for testing
|
||||
type mockCredentials struct {
|
||||
token string
|
||||
}
|
||||
|
||||
func (m *mockCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
return JWTToken{AccessToken: m.token}, nil
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// GoogleWorkspaceManager Google Workspace manager client instance.
|
||||
type GoogleWorkspaceManager struct {
|
||||
usersService *admin.UsersService
|
||||
CustomerID string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
|
||||
type GoogleWorkspaceClientConfig struct {
|
||||
ServiceAccountKey string
|
||||
CustomerID string
|
||||
}
|
||||
|
||||
// GoogleWorkspaceCredentials Google Workspace authentication information.
|
||||
type GoogleWorkspaceCredentials struct {
|
||||
clientConfig GoogleWorkspaceClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
||||
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.CustomerID == "" {
|
||||
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
|
||||
}
|
||||
|
||||
credentials := &GoogleWorkspaceCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
// Create a new Admin SDK Directory service client
|
||||
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := admin.NewService(context.Background(),
|
||||
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
|
||||
option.WithCredentials(adminCredentials),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GoogleWorkspaceManager{
|
||||
usersService: service.Users,
|
||||
CustomerID: config.CustomerID,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from Google Workspace via ID.
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, err := gm.usersService.Get(userID).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
userData := parseGoogleWorkspaceUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
|
||||
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
pageToken := ""
|
||||
for {
|
||||
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
|
||||
if pageToken != "" {
|
||||
call.PageToken(pageToken)
|
||||
}
|
||||
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range resp.Users {
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
}
|
||||
|
||||
pageToken = resp.NextPageToken
|
||||
if pageToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, err := gm.usersService.Get(email).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from GoogleWorkspace.
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
|
||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
// It returns the retrieved credentials or an error if unsuccessful.
|
||||
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
||||
}
|
||||
|
||||
creds, err := google.CredentialsFromJSON(
|
||||
context.Background(),
|
||||
decodeKey,
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err == nil {
|
||||
// No need to fallback to the default Google credentials path
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.WithContext(ctx).Debug("falling back to default google credentials location")
|
||||
|
||||
creds, err = google.FindDefaultCredentials(
|
||||
context.Background(),
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// parseGoogleWorkspaceUser parse google user to UserData.
|
||||
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
|
||||
return &UserData{
|
||||
ID: user.Id,
|
||||
Email: user.PrimaryEmail,
|
||||
Name: user.Name.FullName,
|
||||
}
|
||||
}
|
||||
@@ -11,24 +11,25 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnsetAccountID is a special key to map users without an account ID
|
||||
UnsetAccountID = "unset"
|
||||
)
|
||||
|
||||
// Manager idp manager interface
|
||||
// Note: NetBird is the single source of truth for authorization data (roles, account membership, invite status).
|
||||
// The IdP only stores identity information (email, name, credentials).
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
// CreateUser creates a new user in the IdP. Returns basic user data (ID, email, name).
|
||||
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||
// GetUserDataByID retrieves user identity data from the IdP by user ID.
|
||||
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||
// GetUserByEmail searches for users by email address.
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
// GetAllUsers returns all users from the IdP for cache warming.
|
||||
GetAllUsers(ctx context.Context) ([]*UserData, error)
|
||||
// InviteUserByID resends an invitation to a user who hasn't completed signup.
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
// DeleteUser removes a user from the IdP.
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
// ClientConfig defines common client configuration for the IdP manager
|
||||
type ClientConfig struct {
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
@@ -42,13 +43,10 @@ type ExtraConfig map[string]string
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
type Config struct {
|
||||
ManagerType string
|
||||
ClientConfig *ClientConfig
|
||||
ExtraConfig ExtraConfig
|
||||
Auth0ClientCredentials *Auth0ClientConfig
|
||||
AzureClientCredentials *AzureClientConfig
|
||||
KeycloakClientCredentials *KeycloakClientConfig
|
||||
ZitadelClientCredentials *ZitadelClientConfig
|
||||
ManagerType string
|
||||
ClientConfig *ClientConfig
|
||||
ExtraConfig ExtraConfig
|
||||
ZitadelClientCredentials *ZitadelClientConfig
|
||||
}
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
@@ -67,11 +65,12 @@ type ManagerHelper interface {
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
}
|
||||
|
||||
// UserData represents identity information from the IdP.
|
||||
// Note: Authorization data (account membership, roles, invite status) is stored in NetBird's DB.
|
||||
type UserData struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
AppMetadata AppMetadata `json:"app_metadata"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
}
|
||||
|
||||
func (u *UserData) MarshalBinary() (data []byte, err error) {
|
||||
@@ -91,15 +90,6 @@ func (u *UserData) Unmarshal(data []byte) (err error) {
|
||||
return json.Unmarshal(data, &u)
|
||||
}
|
||||
|
||||
// AppMetadata user app metadata to associate with a profile
|
||||
type AppMetadata struct {
|
||||
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
|
||||
// maps to wt_account_id when json.marshal
|
||||
WTAccountID string `json:"wt_account_id,omitempty"`
|
||||
WTPendingInvite *bool `json:"wt_pending_invite,omitempty"`
|
||||
WTInvitedBy string `json:"wt_invited_by_email,omitempty"`
|
||||
}
|
||||
|
||||
// JWTToken a JWT object that holds information of a token
|
||||
type JWTToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -109,7 +99,8 @@ type JWTToken struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// NewManager returns a new idp manager based on the configuration that it receives
|
||||
// NewManager returns a new idp manager based on the configuration that it receives.
|
||||
// Only Zitadel is supported as the IdP manager.
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
if config.ClientConfig != nil {
|
||||
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
||||
@@ -118,46 +109,6 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
switch strings.ToLower(config.ManagerType) {
|
||||
case "none", "":
|
||||
return nil, nil //nolint:nilnil
|
||||
case "auth0":
|
||||
auth0ClientConfig := config.Auth0ClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
auth0ClientConfig = &Auth0ClientConfig{
|
||||
Audience: config.ExtraConfig["Audience"],
|
||||
AuthIssuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
}
|
||||
}
|
||||
|
||||
return NewAuth0Manager(*auth0ClientConfig, appMetrics)
|
||||
case "azure":
|
||||
azureClientConfig := config.AzureClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
azureClientConfig = &AzureClientConfig{
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
ObjectID: config.ExtraConfig["ObjectId"],
|
||||
GraphAPIEndpoint: config.ExtraConfig["GraphApiEndpoint"],
|
||||
}
|
||||
}
|
||||
|
||||
return NewAzureManager(*azureClientConfig, appMetrics)
|
||||
case "keycloak":
|
||||
keycloakClientConfig := config.KeycloakClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
keycloakClientConfig = &KeycloakClientConfig{
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
AdminEndpoint: config.ExtraConfig["AdminEndpoint"],
|
||||
}
|
||||
}
|
||||
|
||||
return NewKeycloakManager(*keycloakClientConfig, appMetrics)
|
||||
case "zitadel":
|
||||
zitadelClientConfig := config.ZitadelClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
@@ -172,42 +123,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
}
|
||||
|
||||
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||
case "authentik":
|
||||
authentikConfig := AuthentikClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
Username: config.ExtraConfig["Username"],
|
||||
Password: config.ExtraConfig["Password"],
|
||||
}
|
||||
return NewAuthentikManager(authentikConfig, appMetrics)
|
||||
case "okta":
|
||||
oktaClientConfig := OktaClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
||||
case "google":
|
||||
googleClientConfig := GoogleWorkspaceClientConfig{
|
||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||
CustomerID: config.ExtraConfig["CustomerId"],
|
||||
}
|
||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
||||
case "jumpcloud":
|
||||
jumpcloudConfig := JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
||||
case "pocketid":
|
||||
pocketidConfig := PocketIdClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||
}
|
||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
accept = "application/json"
|
||||
)
|
||||
|
||||
// JumpCloudManager JumpCloud manager client instance.
|
||||
type JumpCloudManager struct {
|
||||
client *v1.APIClient
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// JumpCloudClientConfig JumpCloud manager client configurations.
|
||||
type JumpCloudClientConfig struct {
|
||||
APIToken string
|
||||
}
|
||||
|
||||
// JumpCloudCredentials JumpCloud authentication information.
|
||||
type JumpCloudCredentials struct {
|
||||
clientConfig JumpCloudClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewJumpCloudManager creates a new instance of the JumpCloudManager.
|
||||
func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
|
||||
}
|
||||
|
||||
client := v1.NewAPIClient(v1.NewConfiguration())
|
||||
credentials := &JumpCloudCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &JumpCloudManager{
|
||||
client: client,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the JumpCloud user API.
|
||||
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
|
||||
Key: jm.apiToken,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range userList.Results {
|
||||
userData := parseJumpCloudUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
"fields": []string{"email"},
|
||||
},
|
||||
}
|
||||
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
usersData = append(usersData, parseJumpCloudUser(user))
|
||||
}
|
||||
|
||||
return usersData, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
|
||||
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
|
||||
names := []string{user.Firstname, user.Middlename, user.Lastname}
|
||||
return &UserData{
|
||||
Email: user.Email,
|
||||
Name: strings.Join(names, " "),
|
||||
ID: user.Id,
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewJumpCloudManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig JumpCloudClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := JumpCloudClientConfig{
|
||||
APIToken: "test123",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.APIToken = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing APIToken Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,439 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// KeycloakManager keycloak manager client instance.
|
||||
type KeycloakManager struct {
|
||||
adminEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// KeycloakClientConfig keycloak manager client configurations.
|
||||
type KeycloakClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AdminEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// KeycloakCredentials keycloak authentication information.
|
||||
type KeycloakCredentials struct {
|
||||
clientConfig KeycloakClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// keycloakUserAttributes holds additional user data fields.
|
||||
type keycloakUserAttributes map[string][]string
|
||||
|
||||
// keycloakProfile represents a keycloak user profile response.
|
||||
type keycloakProfile struct {
|
||||
ID string `json:"id"`
|
||||
CreatedTimestamp int64 `json:"createdTimestamp"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Attributes keycloakUserAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.AdminEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &KeycloakManager{
|
||||
adminEndpoint: config.AdminEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
|
||||
func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kc.clientConfig.ClientID)
|
||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", kc.clientConfig.GrantType)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = kc.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = kc.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the keycloak Management API.
|
||||
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
kc.mux.Lock()
|
||||
defer kc.mux.Unlock()
|
||||
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if kc.jwtStillValid() {
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := kc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
|
||||
kc.jwtToken = jwtToken
|
||||
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("email", email)
|
||||
q.Add("exact", "true")
|
||||
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile keycloakProfile
|
||||
err = km.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profile.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account profile.
|
||||
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloak by user ID.
|
||||
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
// In the docs, they specified 200, but in the endpoints, they return 204
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
|
||||
body, err := km.get(ctx, "users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(string(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &count, nil
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (kp keycloakProfile) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: kp.Email,
|
||||
Name: kp.Username,
|
||||
ID: kp.ID,
|
||||
}
|
||||
}
|
||||
@@ -1,310 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewKeycloakManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig KeycloakClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := KeycloakClientConfig{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
|
||||
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.ClientSecret = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing ClientSecret Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.TokenEndpoint = ""
|
||||
|
||||
testCase4 := test{
|
||||
name: "Missing TokenEndpoint Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Missing GrantType Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,52 +4,26 @@ import "context"
|
||||
|
||||
// MockIDP is a mock implementation of the IDP interface
|
||||
type MockIDP struct {
|
||||
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
|
||||
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
|
||||
if m.UpdateUserAppMetadataFunc != nil {
|
||||
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAccount is a mock implementation of the IDP interface GetAccount method
|
||||
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
if m.GetAccountFunc != nil {
|
||||
return m.GetAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
|
||||
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if m.GetAllAccountsFunc != nil {
|
||||
return m.GetAllAccountsFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
CreateUserFunc func(ctx context.Context, email, name string) (*UserData, error)
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
GetAllUsersFunc func(ctx context.Context) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||
if m.CreateUserFunc != nil {
|
||||
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
|
||||
return m.CreateUserFunc(ctx, email, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(ctx, userId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -62,6 +36,14 @@ func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllUsers is a mock implementation of the IDP interface GetAllUsers method
|
||||
func (m *MockIDP) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
if m.GetAllUsersFunc != nil {
|
||||
return m.GetAllUsersFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
||||
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
||||
if m.InviteUserByIDFunc != nil {
|
||||
|
||||
@@ -1,306 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// OktaManager okta manager client instance.
|
||||
type OktaManager struct {
|
||||
client *okta.Client
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// OktaClientConfig okta manager client configurations.
|
||||
type OktaClientConfig struct {
|
||||
APIToken string
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// OktaCredentials okta authentication information.
|
||||
type OktaCredentials struct {
|
||||
clientConfig OktaClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewOktaManager creates a new instance of the OktaManager.
|
||||
func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*OktaManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.Issuer = baseURL(config.Issuer)
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
if config.Issuer == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, Issuer is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
_, client, err := okta.NewClient(context.Background(),
|
||||
okta.WithOrgUrl(config.Issuer),
|
||||
okta.WithToken(config.APIToken),
|
||||
okta.WithHttpClientPtr(httpClient),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials := &OktaCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &OktaManager{
|
||||
client: client,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the okta user API.
|
||||
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
||||
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, userData)
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Okta account.
|
||||
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
|
||||
qp := query.NewQueryParams(query.WithLimit(200))
|
||||
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for resp.HasNextPage() {
|
||||
paginatedUsers := make([]*okta.User, 0)
|
||||
resp, err = resp.Next(context.Background(), &paginatedUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
userList = append(userList, paginatedUsers...)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0, len(userList))
|
||||
for _, user := range userList {
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Okta
|
||||
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
|
||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOktaUser parse okta user to UserData.
|
||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||
var oktaUser struct {
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("invalid okta user")
|
||||
}
|
||||
|
||||
if user.Profile != nil {
|
||||
helper := JsonParser{}
|
||||
buf, err := helper.Marshal(*user.Profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = helper.Unmarshal(buf, &oktaUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: oktaUser.Email,
|
||||
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
|
||||
ID: user.Id,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseOktaUser(t *testing.T) {
|
||||
type parseOktaUserTest struct {
|
||||
name string
|
||||
inputProfile *okta.User
|
||||
expectedUserData *UserData
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
}
|
||||
|
||||
parseOktaTestCase1 := parseOktaUserTest{
|
||||
name: "Good Request",
|
||||
inputProfile: &okta.User{
|
||||
Id: "123",
|
||||
Profile: &okta.UserProfile{
|
||||
"email": "test@example.com",
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
},
|
||||
},
|
||||
expectedUserData: &UserData{
|
||||
Email: "test@example.com",
|
||||
Name: "John Doe",
|
||||
ID: "123",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "456",
|
||||
},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
}
|
||||
|
||||
parseOktaTestCase2 := parseOktaUserTest{
|
||||
name: "Invalid okta user",
|
||||
inputProfile: nil,
|
||||
expectedUserData: nil,
|
||||
assertErrFunc: assert.Error,
|
||||
}
|
||||
|
||||
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
userData, err := parseOktaUser(testCase.inputProfile)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
|
||||
|
||||
if err == nil {
|
||||
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// userDataEqual helper function to compare UserData structs for equality.
|
||||
func userDataEqual(a, b *UserData) bool {
|
||||
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,384 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type PocketIdManager struct {
|
||||
managementEndpoint string
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
type pocketIdCustomClaimDto struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type pocketIdUserDto struct {
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
Disabled bool `json:"disabled"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
ID string `json:"id"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
LastName string `json:"lastName"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Locale string `json:"locale"`
|
||||
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdUserCreateDto struct {
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
LastName string `json:"lastName,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdPaginatedUserDto struct {
|
||||
Data []pocketIdUserDto `json:"data"`
|
||||
Pagination pocketIdPaginationDto `json:"pagination"`
|
||||
}
|
||||
|
||||
type pocketIdPaginationDto struct {
|
||||
CurrentPage int `json:"currentPage"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
}
|
||||
|
||||
func (p *pocketIdUserDto) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.DisplayName,
|
||||
ID: p.ID,
|
||||
AppMetadata: AppMetadata{},
|
||||
}
|
||||
}
|
||||
|
||||
type pocketIdUserGroupDto struct {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
ID string `json:"id"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ManagementEndpoint == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
credentials := &PocketIdCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &PocketIdManager{
|
||||
managementEndpoint: config.ManagementEndpoint,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
|
||||
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
|
||||
if !slices.Contains(MethodsWithBody, method) && body != "" {
|
||||
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
|
||||
if query != nil {
|
||||
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
|
||||
}
|
||||
var req *http.Request
|
||||
var err error
|
||||
if body != "" {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
|
||||
} else {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("X-API-KEY", p.apiToken)
|
||||
|
||||
if body != "" {
|
||||
req.Header.Add("content-type", "application/json")
|
||||
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// getAllUsersPaginated fetches all users from PocketID API using pagination
|
||||
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
|
||||
var allUsers []pocketIdUserDto
|
||||
currentPage := 1
|
||||
|
||||
for {
|
||||
params := url.Values{}
|
||||
// Copy existing search parameters
|
||||
for key, values := range searchParams {
|
||||
params[key] = values
|
||||
}
|
||||
|
||||
params.Set("pagination[limit]", "100")
|
||||
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
|
||||
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles pocketIdPaginatedUserDto
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allUsers = append(allUsers, profiles.Data...)
|
||||
|
||||
// Check if we've reached the last page
|
||||
if currentPage >= profiles.Pagination.TotalPages {
|
||||
break
|
||||
}
|
||||
|
||||
currentPage++
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var user pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := user.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountId
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
firstLast := strings.Split(name, " ")
|
||||
|
||||
createUser := pocketIdUserCreateDto{
|
||||
Disabled: false,
|
||||
DisplayName: name,
|
||||
Email: email,
|
||||
FirstName: firstLast[0],
|
||||
LastName: firstLast[1],
|
||||
Username: firstLast[0] + "." + firstLast[1],
|
||||
}
|
||||
payload, err := p.helper.Marshal(createUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var newUser pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &newUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
params := url.Values{
|
||||
// This value a
|
||||
"search": []string{email},
|
||||
}
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profiles struct{ data []pocketIdUserDto }
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.data {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Manager = (*PocketIdManager)(nil)
|
||||
|
||||
type PocketIdClientConfig struct {
|
||||
APIToken string
|
||||
ManagementEndpoint string
|
||||
}
|
||||
|
||||
type PocketIdCredentials struct {
|
||||
clientConfig PocketIdClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
|
||||
|
||||
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewPocketIdManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig PocketIdClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := PocketIdClientConfig{
|
||||
APIToken: "api_token",
|
||||
ManagementEndpoint: "http://localhost",
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
},
|
||||
{
|
||||
name: "Missing ManagementEndpoint",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: defaultTestConfig.APIToken,
|
||||
ManagementEndpoint: "",
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
{
|
||||
name: "Missing APIToken",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: "",
|
||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
md := AppMetadata{WTAccountID: "acc1"}
|
||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "u1", got.ID)
|
||||
assert.Equal(t, "user1@example.com", got.Email)
|
||||
assert.Equal(t, "User One", got.Name)
|
||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
||||
// Single page response with two users
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
assert.Equal(t, "u1", users[0].ID)
|
||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "u2", users[1].ID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accounts[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestPocketID_CreateUser(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newid", ud.ID)
|
||||
assert.Equal(t, "new@example.com", ud.Email)
|
||||
assert.Equal(t, "New User", ud.Name)
|
||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||
}
|
||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||
}
|
||||
|
||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
||||
// Same mock for both calls; returns OK with empty JSON
|
||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.DeleteUser(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
47
management/server/idp/test_util_test.go
Normal file
47
management/server/idp/test_util_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockHTTPClient is a mock implementation of ManagerHTTPClient for testing
|
||||
type mockHTTPClient struct {
|
||||
code int
|
||||
resBody string
|
||||
reqBody string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
if req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: io.NopCloser(bytes.NewReader([]byte(c.resBody))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newTestJWT creates a test JWT token with the given expiration time in seconds
|
||||
func newTestJWT(t *testing.T, expiresIn int) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||
exp := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature"))
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
@@ -319,14 +319,19 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
// Note: Authorization data (account membership, invite status) is stored in NetBird's DB, not in the IdP.
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||
firstLast := strings.SplitN(name, " ", 2)
|
||||
lastName := firstLast[0]
|
||||
if len(firstLast) > 1 {
|
||||
lastName = firstLast[1]
|
||||
}
|
||||
|
||||
var addUser = map[string]any{
|
||||
"userName": email,
|
||||
"profile": map[string]string{
|
||||
"firstName": firstLast[0],
|
||||
"lastName": firstLast[0],
|
||||
"lastName": lastName,
|
||||
"displayName": name,
|
||||
},
|
||||
"email": map[string]any{
|
||||
@@ -357,18 +362,11 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.UserId,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@@ -413,7 +411,7 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from zitadel via ID.
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string) (*UserData, error) {
|
||||
body, err := zm.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -429,43 +427,12 @@ func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, ap
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := profile.User.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
return profile.User.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
var profiles struct{ Result []zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// GetAllUsers returns all users from the IdP.
|
||||
// Used for cache warming - NetBird matches these against its own user database.
|
||||
func (zm *ZitadelManager) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -481,19 +448,12 @@ func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*Use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
users := make([]*UserData, 0, len(profiles.Result))
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
return users, nil
|
||||
}
|
||||
|
||||
type inviteUserRequest struct {
|
||||
|
||||
@@ -288,16 +288,14 @@ func TestZitadelAuthenticate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestZitadelProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
type zitadelProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile zitadelProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "User Request",
|
||||
invite: false,
|
||||
zitadelProfileTestCase1 := zitadelProfileTest{
|
||||
name: "User Request",
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test1",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
@@ -322,15 +320,11 @@ func TestZitadelProfile(t *testing.T) {
|
||||
ID: "test1",
|
||||
Name: "ZITADEL Admin",
|
||||
Email: "test1@mail.com",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Service User Request",
|
||||
invite: true,
|
||||
zitadelProfileTestCase2 := zitadelProfileTest{
|
||||
name: "Service User Request",
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test2",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
@@ -345,15 +339,11 @@ func TestZitadelProfile(t *testing.T) {
|
||||
ID: "test2",
|
||||
Name: "machine",
|
||||
Email: "machine",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
||||
for _, testCase := range []zitadelProfileTest{zitadelProfileTestCase1, zitadelProfileTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -487,7 +486,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID)
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
|
||||
9974
management/server/types/testdata/networkmap_golden.json
vendored
Normal file
9974
management/server/types/testdata/networkmap_golden.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9974
management/server/types/testdata/networkmap_golden_new.json
vendored
Normal file
9974
management/server/types/testdata/networkmap_golden_new.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_new_with_deleted_router.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_new_with_deleted_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded_router.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_new_with_onpeerdeleted.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_new_with_onpeerdeleted.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_with_deleted_peer.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_with_deleted_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_with_deleted_router_peer.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_with_deleted_router_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_with_new_peer.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_with_new_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_with_new_router.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_with_new_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,8 @@ type User struct {
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// PendingInvite indicates whether the user has accepted their invite and logged in
|
||||
PendingInvite bool
|
||||
// PendingApproval indicates whether the user requires approval before being activated
|
||||
PendingApproval bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
@@ -162,7 +164,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
userStatus := UserStatusActive
|
||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||
if u.PendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
@@ -199,6 +201,7 @@ func (u *User) Copy() *User {
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
PendingInvite: u.PendingInvite,
|
||||
PendingApproval: u.PendingApproval,
|
||||
LastLogin: u.LastLogin,
|
||||
CreatedAt: u.CreatedAt,
|
||||
|
||||
@@ -117,6 +117,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
Issued: invite.Issued,
|
||||
IntegrationReference: invite.IntegrationReference,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
PendingInvite: true, // User hasn't accepted invite yet
|
||||
}
|
||||
|
||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||
@@ -169,7 +170,7 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
|
||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email)
|
||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
@@ -285,8 +286,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// check if the user is already registered with this ID
|
||||
user, err := am.lookupUserInCache(ctx, targetUserID, accountID)
|
||||
// Get user from NetBird's database (source of truth for authorization data)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -295,18 +296,17 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
return status.Errorf(status.NotFound, "user account %s doesn't exist", targetUserID)
|
||||
}
|
||||
|
||||
// check if user account is already invited and account is not activated
|
||||
pendingInvite := user.AppMetadata.WTPendingInvite
|
||||
if pendingInvite == nil || !*pendingInvite {
|
||||
// Check if user is still pending invite (hasn't activated their account)
|
||||
if !user.PendingInvite {
|
||||
return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account")
|
||||
}
|
||||
|
||||
err = am.idpManager.InviteUserByID(ctx, user.ID)
|
||||
err = am.idpManager.InviteUserByID(ctx, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, user.ID, accountID, activity.UserInvited, nil)
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserInvited, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1011,17 +1011,15 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error {
|
||||
if am.userDeleteFromIDPEnabled {
|
||||
log.WithContext(ctx).Debugf("user %s deleted from IdP", targetUserID)
|
||||
log.WithContext(ctx).Debugf("deleting user %s from IdP", targetUserID)
|
||||
err := am.idpManager.DeleteUser(ctx, targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
|
||||
}
|
||||
} else {
|
||||
err := am.idpManager.UpdateUserAppMetadata(ctx, targetUserID, idp.AppMetadata{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
|
||||
}
|
||||
}
|
||||
// Note: If userDeleteFromIDPEnabled is false, the user remains in IdP but is removed from NetBird.
|
||||
// This allows the user to re-authenticate if they're re-invited later.
|
||||
|
||||
err := am.removeUserFromCache(ctx, accountID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err)
|
||||
@@ -1095,7 +1093,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
if !isNil(am.idpManager) {
|
||||
// Delete if the user already exists in the IdP. Necessary in cases where a user account
|
||||
// was created where a user account was provisioned but the user did not sign in
|
||||
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID, idp.AppMetadata{WTAccountID: accountID})
|
||||
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID)
|
||||
if err == nil {
|
||||
err = am.deleteUserFromIDP(ctx, targetUserInfo.ID, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -563,7 +563,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
idpMock := idp.MockIDP{
|
||||
CreateUserFunc: func(_ context.Context, email, name, accountID, invitedByEmail string) (*idp.UserData, error) {
|
||||
CreateUserFunc: func(_ context.Context, email, name string) (*idp.UserData, error) {
|
||||
newData := &idp.UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
@@ -574,7 +574,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
|
||||
return newData, nil
|
||||
},
|
||||
GetAccountFunc: func(_ context.Context, accountId string) ([]*idp.UserData, error) {
|
||||
GetAllUsersFunc: func(_ context.Context) ([]*idp.UserData, error) {
|
||||
return mockData, nil
|
||||
},
|
||||
}
|
||||
@@ -1068,7 +1068,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
|
||||
idpManager: &idp.MockIDP{}, // empty manager
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user