mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
refactor(idp): make NetBird single source of truth for authorization
Remove duplicate authorization data from Zitadel IdP. NetBird now stores all authorization data (account membership, invite status, roles) locally, while Zitadel only stores identity information (email, name, credentials). Changes: - Add PendingInvite field to User struct to track invite status locally - Simplify IdP Manager interface: remove metadata methods, add GetAllUsers - Update cache warming to match IdP users against NetBird DB - Remove addAccountIDToIDPAppMeta and all wt_* metadata writes - Delete legacy IdP managers (Auth0, Azure, Keycloak, Okta, Google Workspace, JumpCloud, Authentik, PocketId) - only Zitadel supported
This commit is contained in:
352
docs/plans/self-service-auth-plan.mdx
Normal file
352
docs/plans/self-service-auth-plan.mdx
Normal file
@@ -0,0 +1,352 @@
|
||||
# Zitadel Integration Plan
|
||||
|
||||
This plan is split into stages. Each stage is self-contained and can be implemented and tested independently.
|
||||
|
||||
---
|
||||
|
||||
## Stage 1: Infrastructure Setup [COMPLETED]
|
||||
|
||||
**Goal**: Add Zitadel to docker-compose with first-boot initialization (single container, SQLite).
|
||||
|
||||
### Design Principles
|
||||
- **KISS**: Single Zitadel container with SQLite (no separate PostgreSQL or init containers)
|
||||
- **DRY**: Leverage Zitadel's built-in first-instance configuration via environment variables
|
||||
- **Clean**: Credentials generated in configure.sh and printed directly
|
||||
|
||||
### Files Created
|
||||
|
||||
#### `infrastructure_files/Caddyfile.tmpl`
|
||||
Reverse proxy configuration routing:
|
||||
- `/api/*`, `/management.*` → Management server
|
||||
- `/relay*` → Relay
|
||||
- `/signalexchange.*`, `/signal*` → Signal
|
||||
- `/oauth/*`, `/oidc/*`, `/.well-known/*`, `/ui/*` → Zitadel
|
||||
- `/*` → Dashboard
|
||||
|
||||
### Files Modified
|
||||
|
||||
#### `infrastructure_files/docker-compose.yml.tmpl`
|
||||
Added single Zitadel service using SQLite:
|
||||
```yaml
|
||||
zitadel:
|
||||
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
environment:
|
||||
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
|
||||
# ... service account config via env vars
|
||||
```
|
||||
|
||||
#### `infrastructure_files/management.json.tmpl`
|
||||
Updated IdpManagerConfig to default to Zitadel.
|
||||
|
||||
#### `infrastructure_files/base.setup.env`
|
||||
Added Zitadel-specific environment variables:
|
||||
- `ZITADEL_TAG` (default: v4.7.6)
|
||||
- `ZITADEL_MASTERKEY` (auto-generated)
|
||||
- `ZITADEL_ADMIN_USERNAME`, `ZITADEL_ADMIN_PASSWORD` (auto-generated)
|
||||
- `ZITADEL_EXTERNALSECURE`, `ZITADEL_EXTERNALPORT`, `ZITADEL_TLS_MODE`
|
||||
|
||||
#### `infrastructure_files/configure.sh`
|
||||
- Generates Zitadel masterkey if not set
|
||||
- Generates admin credentials if not set
|
||||
- Prints credentials to stdout during configuration
|
||||
|
||||
### Deliverable
|
||||
Running `./configure.sh && cd artifacts && docker compose up -d` starts full stack with Zitadel. Admin credentials printed during configuration.
|
||||
|
||||
---
|
||||
|
||||
## Stage 2: Simplify IdP Integration
|
||||
|
||||
**Goal**: Make NetBird the single source of truth for user authorization data. Zitadel handles authentication only.
|
||||
|
||||
### Design Principles
|
||||
- **Single source of truth**: NetBird DB stores all authorization data (roles, account membership, invite status)
|
||||
- **Clean separation**: Zitadel stores identity only (email, name, password)
|
||||
- **No duplicate data**: Remove `wt_account_id` and `wt_pending_invite` from IdP metadata
|
||||
|
||||
### Files to Modify
|
||||
|
||||
#### `management/server/types/user.go`
|
||||
Add `PendingInvite` field to User struct:
|
||||
```go
|
||||
type User struct {
|
||||
// ... existing fields ...
|
||||
PendingInvite bool // NEW: tracks if user has accepted invite
|
||||
}
|
||||
```
|
||||
|
||||
Update `ToUserInfo()` to use local field instead of IdP metadata:
|
||||
```go
|
||||
// Before
|
||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
// After
|
||||
if u.PendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
```
|
||||
|
||||
#### `management/server/idp/idp.go`
|
||||
Simplify `Manager` interface:
|
||||
```go
|
||||
type Manager interface {
|
||||
// Simplified - no more accountID or invitedByEmail params
|
||||
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
```
|
||||
|
||||
Remove:
|
||||
- `UpdateUserAppMetadata()` - no longer needed
|
||||
- `GetAccount()` - account membership now in NetBird DB
|
||||
- `GetAllAccounts()` - account membership now in NetBird DB
|
||||
- `AppMetadata` struct fields (`WTAccountID`, `WTPendingInvite`, `WTInvitedBy`)
|
||||
|
||||
#### `management/server/idp/zitadel.go`
|
||||
- Update `CreateUser()` to not set metadata
|
||||
- Remove `UpdateUserAppMetadata()` implementation
|
||||
- Simplify `GetUserDataByID()` to not expect metadata
|
||||
|
||||
#### `management/server/user.go`
|
||||
Update user creation flow:
|
||||
1. `inviteNewUser()`: Set `PendingInvite = true` when creating user
|
||||
2. On first login: Set `PendingInvite = false`
|
||||
|
||||
### Data Flow (After)
|
||||
|
||||
```
|
||||
Invite User:
|
||||
1. Admin invites user@example.com via NetBird UI
|
||||
2. NetBird calls Zitadel: CreateUser(email, name) // No metadata
|
||||
3. Zitadel creates user, sends invite email
|
||||
4. NetBird creates User in its DB:
|
||||
- Id = Zitadel user ID
|
||||
- AccountID = current account
|
||||
- Role = invited role
|
||||
- PendingInvite = true
|
||||
|
||||
First Login:
|
||||
1. User clicks invite, sets password in Zitadel
|
||||
2. User logs into NetBird Dashboard
|
||||
3. NetBird updates User in its DB:
|
||||
- PendingInvite = false
|
||||
- LastLogin = now
|
||||
```
|
||||
|
||||
### Deliverable
|
||||
- NetBird is single source of truth for authorization
|
||||
- Zitadel stores identity only (email, name, credentials)
|
||||
- No metadata sync issues between systems
|
||||
|
||||
---
|
||||
|
||||
## Stage 3: Remove Legacy IdP Managers
|
||||
|
||||
**Goal**: Remove all non-Zitadel IdP implementations.
|
||||
|
||||
### Files to Delete
|
||||
- `management/server/idp/auth0.go`
|
||||
- `management/server/idp/auth0_test.go`
|
||||
- `management/server/idp/azure.go`
|
||||
- `management/server/idp/azure_test.go`
|
||||
- `management/server/idp/keycloak.go`
|
||||
- `management/server/idp/keycloak_test.go`
|
||||
- `management/server/idp/okta.go`
|
||||
- `management/server/idp/okta_test.go`
|
||||
- `management/server/idp/google_workspace.go`
|
||||
- `management/server/idp/google_workspace_test.go`
|
||||
- `management/server/idp/jumpcloud.go`
|
||||
- `management/server/idp/jumpcloud_test.go`
|
||||
- `management/server/idp/authentik.go`
|
||||
- `management/server/idp/authentik_test.go`
|
||||
- `management/server/idp/pocketid.go`
|
||||
- `management/server/idp/pocketid_test.go`
|
||||
|
||||
### Files to Modify
|
||||
|
||||
#### `management/server/idp/idp.go`
|
||||
Simplify `NewManager()` to only support Zitadel:
|
||||
```go
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
switch strings.ToLower(config.ManagerType) {
|
||||
case "none", "":
|
||||
return nil, nil
|
||||
case "zitadel":
|
||||
return NewZitadelManager(...)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### `management/server/idp/idp.go`
|
||||
Remove unused config structs:
|
||||
- `Auth0ClientConfig`
|
||||
- `AzureClientConfig`
|
||||
- `KeycloakClientConfig`
|
||||
- etc.
|
||||
|
||||
### Deliverable
|
||||
Only Zitadel IdP manager remains. Build passes. Tests pass.
|
||||
|
||||
---
|
||||
|
||||
## Stage 4: External IdP Connector API
|
||||
|
||||
**Goal**: API for users to add external IdPs (Okta, Google, LDAP) as Zitadel connectors.
|
||||
|
||||
### New Files
|
||||
|
||||
#### `management/server/idp/connectors.go`
|
||||
Wrapper for Zitadel IdP connector API:
|
||||
```go
|
||||
type Connector struct {
|
||||
ID string
|
||||
Name string
|
||||
Type string // "oidc", "ldap", "saml"
|
||||
Issuer string // for OIDC
|
||||
ClientID string // for OIDC
|
||||
}
|
||||
|
||||
type ConnectorManager interface {
|
||||
AddOIDCConnector(ctx, name, issuer, clientID, clientSecret string, scopes []string) (*Connector, error)
|
||||
AddLDAPConnector(ctx, name, host string, port int, baseDN, bindDN, bindPassword string) (*Connector, error)
|
||||
ListConnectors(ctx) ([]Connector, error)
|
||||
DeleteConnector(ctx, connectorID string) error
|
||||
}
|
||||
```
|
||||
|
||||
#### `management/server/http/handlers/idp_connectors_handler.go`
|
||||
REST API handlers:
|
||||
```
|
||||
POST /api/idp/connectors - Add connector
|
||||
GET /api/idp/connectors - List connectors
|
||||
DELETE /api/idp/connectors/{id} - Delete connector
|
||||
```
|
||||
|
||||
### Zitadel API Endpoints Used
|
||||
- `POST /management/v1/idps/generic_oidc` - Add OIDC connector
|
||||
- `POST /management/v1/idps/ldap` - Add LDAP connector
|
||||
- `GET /management/v1/idps` - List all IdPs
|
||||
- `DELETE /management/v1/idps/{id}` - Remove IdP
|
||||
|
||||
### Deliverable
|
||||
Admin users can add/remove external IdP connectors via NetBird API.
|
||||
|
||||
---
|
||||
|
||||
## Stage 5: User Role Permissions
|
||||
|
||||
**Goal**: Enforce admin/user permissions throughout the application.
|
||||
|
||||
### Files to Modify
|
||||
|
||||
#### `management/server/types/user.go`
|
||||
Simplify roles to:
|
||||
```go
|
||||
const (
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
)
|
||||
```
|
||||
|
||||
Keep `owner` as alias for `admin` for backwards compatibility.
|
||||
|
||||
#### Permission checks (various files)
|
||||
Update permission checks to use simplified role model:
|
||||
- Admin: Full access to all resources
|
||||
- User: View/manage own peers only, no user/policy management
|
||||
|
||||
### Permission Matrix
|
||||
| Resource | Admin | User |
|
||||
|----------|-------|------|
|
||||
| All peers | read/write | - |
|
||||
| Own peers | read/write | read/write |
|
||||
| Users | read/write | - |
|
||||
| Groups | read/write | read (own) |
|
||||
| Policies | read/write | - |
|
||||
| IdP Connectors | read/write | - |
|
||||
| Account settings | read/write | - |
|
||||
|
||||
### Deliverable
|
||||
Role-based access control enforced. User role has limited dashboard access.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ NETBIRD DEPLOYMENT │
|
||||
│ │
|
||||
│ ┌───────────────────────────────────────────────────────────┐ │
|
||||
│ │ CADDY (Reverse Proxy) │ │
|
||||
│ │ :80/:443 → routes to appropriate service │ │
|
||||
│ └───────────────────────────────────────────────────────────┘ │
|
||||
│ │ │ │ │ │
|
||||
│ ▼ ▼ ▼ ▼ │
|
||||
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
|
||||
│ │ Dashboard│ │Management│ │ Signal │ │ Zitadel │ │
|
||||
│ │ :80 │ │ :80 │ │ :80 │ │ :8080 (SQLite) │ │
|
||||
│ └──────────┘ └────┬─────┘ └──────────┘ └──────────────────┘ │
|
||||
│ │ │ │
|
||||
│ │ OIDC + Mgmt API │ │
|
||||
│ └──────────────────────────┘ │
|
||||
│ │
|
||||
│ ┌───────────────────────────────────────────────────────────┐ │
|
||||
│ │ Relay │ Coturn (TURN) │ │
|
||||
│ └───────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. **Stage 1** - Infrastructure (can be tested standalone)
|
||||
2. **Stage 2** - Role management (requires Stage 1)
|
||||
3. **Stage 3** - Remove legacy IdPs (can be done in parallel with Stage 2)
|
||||
4. **Stage 4** - Connector API (requires Stage 1)
|
||||
5. **Stage 5** - Permissions (requires Stage 2)
|
||||
|
||||
Stages 3 and 4 can be worked on in parallel after Stage 1 is complete.
|
||||
|
||||
---
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### First-Boot Credentials
|
||||
- Generated with `openssl rand -base64 32` in configure.sh
|
||||
- Printed to stdout during configuration (save them!)
|
||||
- Marked as "must change on first login" in Zitadel via `PASSWORDCHANGEREQUIRED=true`
|
||||
|
||||
### Zitadel Masterkey
|
||||
- Auto-generated if not provided (32 bytes)
|
||||
- Stored in environment variable (docker secret in production)
|
||||
- Used for Zitadel's internal encryption
|
||||
|
||||
### Service Account
|
||||
- Machine user created via `ZITADEL_FIRSTINSTANCE_ORG_MACHINE_*` env vars
|
||||
- PAT (Personal Access Token) generated with 1-year expiration
|
||||
- Used by NetBird management to call Zitadel API
|
||||
|
||||
---
|
||||
|
||||
## Migration Notes
|
||||
|
||||
Existing deployments using Auth0/Azure/Keycloak will need to:
|
||||
1. Deploy Zitadel alongside existing IdP
|
||||
2. Add existing IdP as Zitadel connector
|
||||
3. Migrate users (or let them re-authenticate via connector)
|
||||
4. Switch NetBird config to use Zitadel
|
||||
5. Remove old IdP
|
||||
|
||||
A migration guide should be provided as documentation.
|
||||
83
infrastructure_files/Caddyfile.tmpl
Normal file
83
infrastructure_files/Caddyfile.tmpl
Normal file
@@ -0,0 +1,83 @@
|
||||
{
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h2 h3
|
||||
}
|
||||
}
|
||||
|
||||
(security_headers) {
|
||||
header * {
|
||||
# HSTS - use 1 hour for testing, increase to 63072000 (2 years) in production
|
||||
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||
# Prevent MIME type sniffing
|
||||
X-Content-Type-Options "nosniff"
|
||||
# Clickjacking protection
|
||||
X-Frame-Options "SAMEORIGIN"
|
||||
# XSS protection
|
||||
X-XSS-Protection "1; mode=block"
|
||||
# Remove server header
|
||||
-Server
|
||||
# Referrer policy
|
||||
Referrer-Policy strict-origin-when-cross-origin
|
||||
}
|
||||
}
|
||||
|
||||
:${NETBIRD_CADDY_PORT}${CADDY_SECURE_DOMAIN} {
|
||||
import security_headers
|
||||
|
||||
# Relay
|
||||
reverse_proxy /relay* relay:${NETBIRD_RELAY_INTERNAL_PORT}
|
||||
|
||||
# Signal - WebSocket proxy
|
||||
reverse_proxy /ws-proxy/signal* signal:80
|
||||
# Signal - gRPC
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||
|
||||
# Management - REST API
|
||||
reverse_proxy /api/* management:80
|
||||
# Management - WebSocket proxy
|
||||
reverse_proxy /ws-proxy/management* management:80
|
||||
# Management - gRPC
|
||||
reverse_proxy /management.ManagementService/* h2c://management:80
|
||||
|
||||
# Zitadel - Admin API
|
||||
reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080
|
||||
reverse_proxy /admin/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Auth API
|
||||
reverse_proxy /zitadel.auth.v1.AuthService/* h2c://zitadel:8080
|
||||
reverse_proxy /auth/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Management API
|
||||
reverse_proxy /zitadel.management.v1.ManagementService/* h2c://zitadel:8080
|
||||
reverse_proxy /management/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - System API
|
||||
reverse_proxy /zitadel.system.v1.SystemService/* h2c://zitadel:8080
|
||||
reverse_proxy /system/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - User API v2
|
||||
reverse_proxy /zitadel.user.v2.UserService/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Assets
|
||||
reverse_proxy /assets/v1/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - UI (login, console, etc.)
|
||||
reverse_proxy /ui/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - OIDC endpoints
|
||||
reverse_proxy /oidc/v1/* h2c://zitadel:8080
|
||||
reverse_proxy /oauth/v2/* h2c://zitadel:8080
|
||||
reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080
|
||||
|
||||
# Zitadel - SAML
|
||||
reverse_proxy /saml/v2/* h2c://zitadel:8080
|
||||
|
||||
# Zitadel - Other
|
||||
reverse_proxy /openapi/* h2c://zitadel:8080
|
||||
reverse_proxy /debug/* h2c://zitadel:8080
|
||||
reverse_proxy /device/* h2c://zitadel:8080
|
||||
reverse_proxy /device h2c://zitadel:8080
|
||||
|
||||
# Dashboard - catch-all for frontend
|
||||
reverse_proxy /* dashboard:80
|
||||
}
|
||||
@@ -2,29 +2,20 @@
|
||||
|
||||
# Management API
|
||||
|
||||
# Management API port
|
||||
NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073}
|
||||
# Management API endpoint address, used by the Dashboard
|
||||
NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
||||
# Management Certificate file path. These are generated by the Dashboard container
|
||||
NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem"
|
||||
# Management Certificate key file path.
|
||||
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem"
|
||||
# Management API endpoint address, used by the Dashboard (Caddy handles TLS)
|
||||
NETBIRD_MGMT_API_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
|
||||
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
|
||||
NETBIRD_SIGNAL_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
|
||||
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-443}
|
||||
|
||||
# Relay
|
||||
NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN}
|
||||
NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080}
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-rel://$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT}
|
||||
# Relay (internal port for Caddy reverse proxy)
|
||||
NETBIRD_RELAY_INTERNAL_PORT=${NETBIRD_RELAY_INTERNAL_PORT:-80}
|
||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-${NETBIRD_RELAY_PROTO:-rels}://$NETBIRD_DOMAIN:${NETBIRD_RELAY_PORT:-443}}
|
||||
# Relay auth secret
|
||||
NETBIRD_RELAY_AUTH_SECRET=
|
||||
|
||||
@@ -141,3 +132,57 @@ export NETBIRD_RELAY_ENDPOINT
|
||||
export NETBIRD_RELAY_AUTH_SECRET
|
||||
export NETBIRD_RELAY_TAG
|
||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
# Zitadel IdP Configuration
|
||||
ZITADEL_TAG=${ZITADEL_TAG:-"v4.7.6"}
|
||||
# Zitadel masterkey (32 bytes, auto-generated if not set)
|
||||
ZITADEL_MASTERKEY=
|
||||
# Zitadel admin credentials (auto-generated if not set)
|
||||
ZITADEL_ADMIN_USERNAME=
|
||||
ZITADEL_ADMIN_PASSWORD=
|
||||
# Zitadel external configuration
|
||||
ZITADEL_EXTERNALSECURE=${ZITADEL_EXTERNALSECURE:-true}
|
||||
ZITADEL_EXTERNALPORT=${ZITADEL_EXTERNALPORT:-443}
|
||||
ZITADEL_TLS_MODE=${ZITADEL_TLS_MODE:-external}
|
||||
# Zitadel PAT expiration (1 year from startup)
|
||||
ZITADEL_PAT_EXPIRATION=
|
||||
# Zitadel management endpoint
|
||||
ZITADEL_MANAGEMENT_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN/management/v1
|
||||
# HTTP protocol (http or https)
|
||||
NETBIRD_HTTP_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
|
||||
# Caddy configuration
|
||||
NETBIRD_CADDY_PORT=${NETBIRD_CADDY_PORT:-80}
|
||||
CADDY_SECURE_DOMAIN=
|
||||
|
||||
# Zitadel OIDC endpoints
|
||||
NETBIRD_AUTH_AUTHORITY=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
|
||||
NETBIRD_AUTH_TOKEN_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token
|
||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration
|
||||
NETBIRD_AUTH_JWT_CERTS=${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json
|
||||
NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize
|
||||
NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization
|
||||
NETBIRD_AUTH_USER_ID_CLAIM=${NETBIRD_AUTH_USER_ID_CLAIM:-sub}
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-"openid profile email offline_access"}
|
||||
|
||||
# Zitadel exports
|
||||
export ZITADEL_TAG
|
||||
export ZITADEL_MASTERKEY
|
||||
export ZITADEL_ADMIN_USERNAME
|
||||
export ZITADEL_ADMIN_PASSWORD
|
||||
export ZITADEL_EXTERNALSECURE
|
||||
export ZITADEL_EXTERNALPORT
|
||||
export ZITADEL_TLS_MODE
|
||||
export ZITADEL_PAT_EXPIRATION
|
||||
export ZITADEL_MANAGEMENT_ENDPOINT
|
||||
export NETBIRD_HTTP_PROTOCOL
|
||||
export NETBIRD_CADDY_PORT
|
||||
export CADDY_SECURE_DOMAIN
|
||||
export NETBIRD_AUTH_AUTHORITY
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT
|
||||
export NETBIRD_AUTH_JWT_CERTS
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT
|
||||
export NETBIRD_AUTH_USER_ID_CLAIM
|
||||
export NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
export NETBIRD_RELAY_INTERNAL_PORT
|
||||
|
||||
@@ -1,261 +1,137 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which curl >/dev/null 2>&1; then
|
||||
echo "This script uses curl fetch OpenID configuration from IDP."
|
||||
echo "Please install curl and re-run the script https://curl.se/"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! which jq >/dev/null 2>&1; then
|
||||
echo "This script uses jq to load OpenID configuration from IDP."
|
||||
echo "Please install jq and re-run the script https://stedolan.github.io/jq/"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
# Check required dependencies
|
||||
for cmd in curl jq envsubst openssl; do
|
||||
if ! which $cmd >/dev/null 2>&1; then
|
||||
echo "This script requires $cmd. Please install it and re-run."
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
# Source configuration
|
||||
source setup.env
|
||||
source base.setup.env
|
||||
|
||||
if ! which envsubst >/dev/null 2>&1; then
|
||||
echo "envsubst is needed to run this script"
|
||||
if [[ $(uname) == "Darwin" ]]; then
|
||||
echo "you can install it with homebrew (https://brew.sh):"
|
||||
echo "brew install gettext"
|
||||
else
|
||||
if which apt-get >/dev/null 2>&1; then
|
||||
echo "you can install it by running"
|
||||
echo "apt-get update && apt-get install gettext-base"
|
||||
else
|
||||
echo "you can install it by installing the package gettext with your package manager"
|
||||
fi
|
||||
fi
|
||||
# Validate required variables
|
||||
if [[ -z "$NETBIRD_DOMAIN" ]]; then
|
||||
echo "NETBIRD_DOMAIN is not set, please update your setup.env file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then
|
||||
echo NETBIRD_DOMAIN is not set, please update your setup.env file
|
||||
echo If you are migrating from old versions, you might need to update your variables prefixes from
|
||||
echo WIRETRUSTEE_.. TO NETBIRD_
|
||||
# Check database configuration if using external database
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" && -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
|
||||
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if PostgreSQL is set as the store engine
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then
|
||||
# Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set
|
||||
if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
|
||||
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
|
||||
echo "Please add the following line to your setup.env file:"
|
||||
echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<PG_HOST> user=<PG_USER> password=<PG_PASSWORD> dbname=<PG_DB_NAME> port=<PG_PORT>"'
|
||||
exit 1
|
||||
fi
|
||||
export NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" && -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
|
||||
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if MySQL is set as the store engine
|
||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" ]]; then
|
||||
# Exit if 'NETBIRD_STORE_ENGINE_MYSQL_DSN' is not set
|
||||
if [[ -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
|
||||
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
|
||||
echo "Please add the following line to your setup.env file:"
|
||||
echo 'NETBIRD_STORE_ENGINE_MYSQL_DSN="<username>:<password>@tcp(127.0.0.1:3306)/<database>"'
|
||||
exit 1
|
||||
fi
|
||||
export NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||
fi
|
||||
|
||||
# local development or tests
|
||||
# Configure for local development vs production
|
||||
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
|
||||
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
|
||||
export NETBIRD_MGMT_API_ENDPOINT=http://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
# if not provided, we generate a turn password
|
||||
if [[ "x-$TURN_PASSWORD" == "x-" ]]; then
|
||||
export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
fi
|
||||
|
||||
TURN_EXTERNAL_IP_CONFIG="#"
|
||||
|
||||
if [[ "x-$NETBIRD_TURN_EXTERNAL_IP" == "x-" ]]; then
|
||||
echo "discovering server's public IP"
|
||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
|
||||
if [[ "x-$IP" != "x-" ]]; then
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||
else
|
||||
echo "unable to discover server's public IP"
|
||||
fi
|
||||
export NETBIRD_MGMT_API_ENDPOINT="http://$NETBIRD_DOMAIN"
|
||||
export NETBIRD_HTTP_PROTOCOL="http"
|
||||
export ZITADEL_EXTERNALSECURE="false"
|
||||
export ZITADEL_EXTERNALPORT="80"
|
||||
export ZITADEL_TLS_MODE="disabled"
|
||||
export NETBIRD_RELAY_PROTO="rel"
|
||||
else
|
||||
echo "${NETBIRD_TURN_EXTERNAL_IP}"| egrep '([0-9]{1,3}\.){3}[0-9]{1,3}$' > /dev/null
|
||||
if [[ $? -eq 0 ]]; then
|
||||
echo "using provided server's public IP"
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
|
||||
else
|
||||
echo "provided NETBIRD_TURN_EXTERNAL_IP $NETBIRD_TURN_EXTERNAL_IP is invalid, please correct it and try again"
|
||||
exit 1
|
||||
fi
|
||||
export NETBIRD_HTTP_PROTOCOL="https"
|
||||
export ZITADEL_EXTERNALSECURE="true"
|
||||
export ZITADEL_EXTERNALPORT="443"
|
||||
export ZITADEL_TLS_MODE="external"
|
||||
export NETBIRD_RELAY_PROTO="rels"
|
||||
export CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:443"
|
||||
fi
|
||||
|
||||
# Auto-generate secrets if not provided
|
||||
[[ -z "$TURN_PASSWORD" ]] && export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
[[ -z "$NETBIRD_RELAY_AUTH_SECRET" ]] && export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
[[ -z "$ZITADEL_MASTERKEY" ]] && export ZITADEL_MASTERKEY=$(openssl rand -base64 32 | head -c 32)
|
||||
|
||||
# Generate Zitadel admin credentials if not provided
|
||||
if [[ -z "$ZITADEL_ADMIN_USERNAME" ]]; then
|
||||
export ZITADEL_ADMIN_USERNAME="admin@${NETBIRD_DOMAIN}"
|
||||
fi
|
||||
if [[ -z "$ZITADEL_ADMIN_PASSWORD" ]]; then
|
||||
export ZITADEL_ADMIN_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')!"
|
||||
fi
|
||||
|
||||
# Set Zitadel PAT expiration (1 year from now)
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
export ZITADEL_PAT_EXPIRATION=$(date -u -v+1y "+%Y-%m-%dT%H:%M:%SZ")
|
||||
else
|
||||
export ZITADEL_PAT_EXPIRATION=$(date -u -d "+1 year" "+%Y-%m-%dT%H:%M:%SZ")
|
||||
fi
|
||||
|
||||
# Discover external IP for TURN
|
||||
TURN_EXTERNAL_IP_CONFIG="#"
|
||||
if [[ -z "$NETBIRD_TURN_EXTERNAL_IP" ]]; then
|
||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip' 2>/dev/null || echo "")
|
||||
[[ -n "$IP" ]] && TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||
elif echo "$NETBIRD_TURN_EXTERNAL_IP" | grep -qE '^([0-9]{1,3}\.){3}[0-9]{1,3}$'; then
|
||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
|
||||
fi
|
||||
export TURN_EXTERNAL_IP_CONFIG
|
||||
|
||||
# if not provided, we generate a relay auth secret
|
||||
if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then
|
||||
export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
|
||||
fi
|
||||
|
||||
artifacts_path="./artifacts"
|
||||
mkdir -p $artifacts_path
|
||||
# Configure endpoints
|
||||
export NETBIRD_AUTH_AUTHORITY="${NETBIRD_HTTP_PROTOCOL}://${NETBIRD_DOMAIN}"
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token"
|
||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration"
|
||||
export NETBIRD_AUTH_JWT_CERTS="${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json"
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize"
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization"
|
||||
export ZITADEL_MANAGEMENT_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/management/v1"
|
||||
export NETBIRD_RELAY_ENDPOINT="${NETBIRD_RELAY_PROTO}://${NETBIRD_DOMAIN}:${ZITADEL_EXTERNALPORT}"
|
||||
|
||||
# Volume names (with backwards compatibility)
|
||||
MGMT_VOLUMENAME="${VOLUME_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||
SIGNAL_VOLUMENAME="${VOLUME_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||
LETSENCRYPT_VOLUMENAME="${VOLUME_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
|
||||
# if volume with wiretrustee- prefix already exists, use it, else create new with netbird-
|
||||
OLD_PREFIX='wiretrustee-'
|
||||
if docker volume ls | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"; then
|
||||
MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||
fi
|
||||
if docker volume ls | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"; then
|
||||
SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||
fi
|
||||
if docker volume ls | grep -q "${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"; then
|
||||
LETSENCRYPT_VOLUMENAME="${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
|
||||
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}" && MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}" && SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||
export MGMT_VOLUMENAME SIGNAL_VOLUMENAME
|
||||
|
||||
# Preserve existing encryption key
|
||||
if test -f 'management.json'; then
|
||||
encKey=$(jq -r ".DataStoreEncryptionKey" management.json 2>/dev/null || echo "null")
|
||||
[[ "$encKey" != "null" && -n "$encKey" ]] && export NETBIRD_DATASTORE_ENC_KEY="$encKey"
|
||||
fi
|
||||
|
||||
export MGMT_VOLUMENAME
|
||||
export SIGNAL_VOLUMENAME
|
||||
export LETSENCRYPT_VOLUMENAME
|
||||
|
||||
#backwards compatibility after migrating to generic OIDC with Auth0
|
||||
if [[ -z "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" ]]; then
|
||||
|
||||
if [[ -z "${NETBIRD_AUTH0_DOMAIN}" ]]; then
|
||||
# not a backward compatible state
|
||||
echo "NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT property must be set in the setup.env file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "It seems like you provided an old setup.env file."
|
||||
echo "Since the release of v0.8.10, we introduced a new set of properties."
|
||||
echo "The script is backward compatible and will continue automatically."
|
||||
echo "In the future versions it will be deprecated. Please refer to the documentation to learn about the changes http://netbird.io/docs/getting-started/self-hosting"
|
||||
|
||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://${NETBIRD_AUTH0_DOMAIN}/.well-known/openid-configuration"
|
||||
export NETBIRD_USE_AUTH0="true"
|
||||
export NETBIRD_AUTH_AUDIENCE=${NETBIRD_AUTH0_AUDIENCE}
|
||||
export NETBIRD_AUTH_CLIENT_ID=${NETBIRD_AUTH0_CLIENT_ID}
|
||||
fi
|
||||
|
||||
echo "loading OpenID configuration from ${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT} to the openid-configuration.json file"
|
||||
curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o ${artifacts_path}/openid-configuration.json
|
||||
|
||||
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' ${artifacts_path}/openid-configuration.json)
|
||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=$(jq -r '.authorization_endpoint' ${artifacts_path}/openid-configuration.json)
|
||||
|
||||
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
||||
# user enabled Device Authorization Grant feature
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
||||
fi
|
||||
|
||||
if [ "$NETBIRD_TOKEN_SOURCE" = "idToken" ]; then
|
||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN=true
|
||||
fi
|
||||
|
||||
# Check if letsencrypt was disabled
|
||||
if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
||||
export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443"
|
||||
export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT"
|
||||
export NETBIRD_RELAY_ENDPOINT="rels://$NETBIRD_DOMAIN:$NETBIRD_RELAY_PORT/relay"
|
||||
|
||||
echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore"
|
||||
echo " and a reverse-proxy with Https needs to be placed in front of netbird!"
|
||||
echo "The following forwards have to be setup:"
|
||||
echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80"
|
||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT"
|
||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT"
|
||||
echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80"
|
||||
echo "- $NETBIRD_RELAY_ENDPOINT/ -http-> relay:33080"
|
||||
echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script."
|
||||
echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!"
|
||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||
echo ""
|
||||
|
||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
fi
|
||||
|
||||
# Check if management identity provider is set
|
||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
||||
EXTRA_CONFIG={}
|
||||
|
||||
# extract extra config from all env prefixed with NETBIRD_IDP_MGMT_EXTRA_
|
||||
for var in ${!NETBIRD_IDP_MGMT_EXTRA_*}; do
|
||||
# convert key snake case to camel case
|
||||
key=$(
|
||||
echo "${var#NETBIRD_IDP_MGMT_EXTRA_}" | awk -F "_" \
|
||||
'{for (i=1; i<=NF; i++) {output=output substr($i,1,1) tolower(substr($i,2))} print output}'
|
||||
)
|
||||
value="${!var}"
|
||||
|
||||
echo "$var"
|
||||
EXTRA_CONFIG=$(jq --arg k "$key" --arg v "$value" '.[$k] = $v' <<<"$EXTRA_CONFIG")
|
||||
done
|
||||
|
||||
export NETBIRD_MGMT_IDP
|
||||
export NETBIRD_IDP_MGMT_CLIENT_ID
|
||||
export NETBIRD_IDP_MGMT_CLIENT_SECRET
|
||||
export NETBIRD_IDP_MGMT_EXTRA_CONFIG=$EXTRA_CONFIG
|
||||
else
|
||||
export NETBIRD_IDP_MGMT_EXTRA_CONFIG={}
|
||||
fi
|
||||
|
||||
IFS=',' read -r -a REDIRECT_URL_PORTS <<< "$NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS"
|
||||
REDIRECT_URLS=""
|
||||
for port in "${REDIRECT_URL_PORTS[@]}"; do
|
||||
REDIRECT_URLS+="\"http://localhost:${port}\","
|
||||
# Create artifacts directory and backup existing files
|
||||
artifacts_path="./artifacts"
|
||||
mkdir -p "$artifacts_path"
|
||||
bkp_postfix="$(date +%s)"
|
||||
for file in docker-compose.yml management.json turnserver.conf Caddyfile; do
|
||||
[[ -f "${artifacts_path}/${file}" ]] && cp "${artifacts_path}/${file}" "${artifacts_path}/${file}.bkp.${bkp_postfix}"
|
||||
done
|
||||
|
||||
export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,}
|
||||
# Generate configuration files
|
||||
envsubst < docker-compose.yml.tmpl > "$artifacts_path/docker-compose.yml"
|
||||
envsubst < management.json.tmpl | jq . > "$artifacts_path/management.json"
|
||||
envsubst < turnserver.conf.tmpl > "$artifacts_path/turnserver.conf"
|
||||
envsubst < Caddyfile.tmpl > "$artifacts_path/Caddyfile"
|
||||
|
||||
# Remove audience for providers that do not support it
|
||||
if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
|
||||
export NETBIRD_DASH_AUTH_AUDIENCE=none
|
||||
export NETBIRD_AUTH_PKCE_AUDIENCE=
|
||||
fi
|
||||
|
||||
# Read the encryption key
|
||||
if test -f 'management.json'; then
|
||||
encKey=$(jq -r ".DataStoreEncryptionKey" management.json)
|
||||
if [[ "$encKey" != "null" ]]; then
|
||||
export NETBIRD_DATASTORE_ENC_KEY=$encKey
|
||||
|
||||
fi
|
||||
fi
|
||||
|
||||
env | grep NETBIRD
|
||||
|
||||
bkp_postfix="$(date +%s)"
|
||||
if test -f "${artifacts_path}/docker-compose.yml"; then
|
||||
cp $artifacts_path/docker-compose.yml "${artifacts_path}/docker-compose.yml.bkp.${bkp_postfix}"
|
||||
fi
|
||||
|
||||
if test -f "${artifacts_path}/management.json"; then
|
||||
cp $artifacts_path/management.json "${artifacts_path}/management.json.bkp.${bkp_postfix}"
|
||||
fi
|
||||
|
||||
if test -f "${artifacts_path}/turnserver.conf"; then
|
||||
cp ${artifacts_path}/turnserver.conf "${artifacts_path}/turnserver.conf.bkp.${bkp_postfix}"
|
||||
fi
|
||||
envsubst <docker-compose.yml.tmpl >$artifacts_path/docker-compose.yml
|
||||
envsubst <management.json.tmpl | jq . >$artifacts_path/management.json
|
||||
envsubst <turnserver.conf.tmpl >$artifacts_path/turnserver.conf
|
||||
# Print summary
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " NetBird Configuration Complete"
|
||||
echo "=========================================="
|
||||
echo " Domain: $NETBIRD_DOMAIN"
|
||||
echo " Protocol: $NETBIRD_HTTP_PROTOCOL"
|
||||
echo " Zitadel: $ZITADEL_TAG (SQLite)"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo " ADMIN CREDENTIALS (save these!):"
|
||||
echo " Username: $ZITADEL_ADMIN_USERNAME"
|
||||
echo " Password: $ZITADEL_ADMIN_PASSWORD"
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "To start NetBird:"
|
||||
echo " cd $artifacts_path && docker compose up -d"
|
||||
echo ""
|
||||
|
||||
@@ -7,108 +7,137 @@ x-default: &default
|
||||
max-file: '2'
|
||||
|
||||
services:
|
||||
# Caddy reverse proxy
|
||||
caddy:
|
||||
<<: *default
|
||||
image: caddy:2
|
||||
networks: [netbird]
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
volumes:
|
||||
- netbird-caddy-data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile:ro
|
||||
depends_on:
|
||||
zitadel:
|
||||
condition: service_healthy
|
||||
|
||||
# UI dashboard
|
||||
dashboard:
|
||||
<<: *default
|
||||
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||
ports:
|
||||
- 80:80
|
||||
- 443:443
|
||||
networks: [netbird]
|
||||
environment:
|
||||
# Endpoints
|
||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
# OIDC
|
||||
- AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE
|
||||
- AUTH_AUDIENCE=$NETBIRD_AUTH_CLIENT_ID
|
||||
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||
- AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET
|
||||
- AUTH_CLIENT_SECRET=
|
||||
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
|
||||
- USE_AUTH0=$NETBIRD_USE_AUTH0
|
||||
- USE_AUTH0=false
|
||||
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
|
||||
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
|
||||
- AUTH_REDIRECT_URI=/nb-auth
|
||||
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
|
||||
# SSL
|
||||
- NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
volumes:
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
||||
- LETSENCRYPT_DOMAIN=none
|
||||
depends_on:
|
||||
zitadel:
|
||||
condition: service_healthy
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
<<: *default
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
depends_on:
|
||||
- dashboard
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
ports:
|
||||
- $NETBIRD_SIGNAL_PORT:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"--log-file", "console"
|
||||
]
|
||||
command: ["--log-file", "console"]
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
<<: *default
|
||||
image: netbirdio/relay:$NETBIRD_RELAY_TAG
|
||||
networks: [netbird]
|
||||
environment:
|
||||
- NB_LOG_LEVEL=info
|
||||
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
|
||||
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
||||
# todo: change to a secure secret
|
||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
ports:
|
||||
- $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
|
||||
- NB_LOG_LEVEL=info
|
||||
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_INTERNAL_PORT
|
||||
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||
|
||||
# Management
|
||||
management:
|
||||
<<: *default
|
||||
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
|
||||
depends_on:
|
||||
- dashboard
|
||||
networks: [netbird]
|
||||
volumes:
|
||||
- $MGMT_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
- ./management.json:/etc/netbird/management.json
|
||||
ports:
|
||||
- $NETBIRD_MGMT_API_PORT:443 #API port
|
||||
# # command for Let's Encrypt validation without dashboard container
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--port", "443",
|
||||
"--port", "80",
|
||||
"--log-file", "console",
|
||||
"--log-level", "info",
|
||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
]
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN",
|
||||
"--idp-sign-key-refresh-enabled"
|
||||
]
|
||||
environment:
|
||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||
|
||||
depends_on:
|
||||
zitadel:
|
||||
condition: service_healthy
|
||||
|
||||
# Coturn
|
||||
coturn:
|
||||
<<: *default
|
||||
image: coturn/coturn:$COTURN_TAG
|
||||
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
|
||||
# - ./cert.pem:/etc/coturn/certs/cert.pem:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
|
||||
# Zitadel Identity Provider (single container with SQLite)
|
||||
zitadel:
|
||||
<<: *default
|
||||
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
|
||||
networks: [netbird]
|
||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||
environment:
|
||||
- ZITADEL_LOG_LEVEL=info
|
||||
- ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
|
||||
- ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
|
||||
- ZITADEL_TLS_ENABLED=false
|
||||
- ZITADEL_EXTERNALPORT=$ZITADEL_EXTERNALPORT
|
||||
- ZITADEL_EXTERNALDOMAIN=$NETBIRD_DOMAIN
|
||||
# SQLite database (no separate PostgreSQL container needed)
|
||||
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
|
||||
# First instance: Admin user
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORDCHANGEREQUIRED=true
|
||||
# First instance: Service account for NetBird management
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=netbird-service
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=NetBird Service Account
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINEKEY_TYPE=1
|
||||
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZITADEL_PAT_EXPIRATION
|
||||
volumes:
|
||||
- netbird-zitadel-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "/app/zitadel", "ready"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
$SIGNAL_VOLUMENAME:
|
||||
$LETSENCRYPT_VOLUMENAME:
|
||||
netbird-caddy-data:
|
||||
netbird-zitadel-data:
|
||||
|
||||
networks:
|
||||
netbird:
|
||||
|
||||
@@ -45,18 +45,18 @@
|
||||
"Engine": "$NETBIRD_STORE_CONFIG_ENGINE"
|
||||
},
|
||||
"HttpConfig": {
|
||||
"Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT",
|
||||
"Address": "0.0.0.0:80",
|
||||
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
|
||||
"AuthAudience": "$NETBIRD_AUTH_CLIENT_ID",
|
||||
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
|
||||
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
|
||||
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"IdpSignKeyRefreshEnabled": $NETBIRD_MGMT_IDP_SIGNKEY_REFRESH,
|
||||
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
||||
"CertFile": "",
|
||||
"CertKey": "",
|
||||
"IdpSignKeyRefreshEnabled": true,
|
||||
"OIDCConfigEndpoint": "$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
||||
},
|
||||
"IdpManagerConfig": {
|
||||
"ManagerType": "$NETBIRD_MGMT_IDP",
|
||||
"ManagerType": "zitadel",
|
||||
"ClientConfig": {
|
||||
"Issuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
@@ -64,40 +64,28 @@
|
||||
"ClientSecret": "$NETBIRD_IDP_MGMT_CLIENT_SECRET",
|
||||
"GrantType": "client_credentials"
|
||||
},
|
||||
"ExtraConfig": $NETBIRD_IDP_MGMT_EXTRA_CONFIG,
|
||||
"Auth0ClientCredentials": null,
|
||||
"AzureClientCredentials": null,
|
||||
"KeycloakClientCredentials": null,
|
||||
"ZitadelClientCredentials": null
|
||||
},
|
||||
"ExtraConfig": {
|
||||
"ManagementEndpoint": "$ZITADEL_MANAGEMENT_ENDPOINT"
|
||||
}
|
||||
},
|
||||
"DeviceAuthorizationFlow": {
|
||||
"Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER",
|
||||
"Provider": "hosted",
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE",
|
||||
"AuthorizationEndpoint": "",
|
||||
"Domain": "$NETBIRD_AUTH0_DOMAIN",
|
||||
"ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID",
|
||||
"ClientSecret": "",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
||||
"Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE",
|
||||
"UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN,
|
||||
"RedirectURLs": null
|
||||
}
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
||||
"Scope": "openid"
|
||||
}
|
||||
},
|
||||
"PKCEAuthorizationFlow": {
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID",
|
||||
"ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET",
|
||||
"Domain": "",
|
||||
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||
"AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
|
||||
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
|
||||
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
|
||||
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
|
||||
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
||||
"Scope": "openid profile email offline_access",
|
||||
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,117 +1,65 @@
|
||||
## example file, you can copy this file to setup.env and update its values
|
||||
##
|
||||
# NetBird Self-Hosted Setup Configuration
|
||||
# Copy this file to setup.env and configure the required values
|
||||
|
||||
# Image tags
|
||||
# you can force specific tags for each component; will be set to latest if empty
|
||||
# -------------------------------------------
|
||||
# Required: Domain Configuration
|
||||
# -------------------------------------------
|
||||
# Your NetBird domain (e.g., netbird.mydomain.com)
|
||||
NETBIRD_DOMAIN=""
|
||||
|
||||
# -------------------------------------------
|
||||
# Optional: Image Tags
|
||||
# -------------------------------------------
|
||||
# Leave empty to use 'latest' for all components
|
||||
NETBIRD_DASHBOARD_TAG=""
|
||||
NETBIRD_SIGNAL_TAG=""
|
||||
NETBIRD_MANAGEMENT_TAG=""
|
||||
COTURN_TAG=""
|
||||
NETBIRD_RELAY_TAG=""
|
||||
|
||||
# Dashboard domain. e.g. app.mydomain.com
|
||||
NETBIRD_DOMAIN=""
|
||||
# Zitadel version (default: v2.64.1)
|
||||
ZITADEL_TAG=""
|
||||
|
||||
# TURN server domain. e.g. turn.mydomain.com
|
||||
# if not specified it will assume NETBIRD_DOMAIN
|
||||
# -------------------------------------------
|
||||
# Optional: TURN Server Configuration
|
||||
# -------------------------------------------
|
||||
# TURN server domain (defaults to NETBIRD_DOMAIN)
|
||||
NETBIRD_TURN_DOMAIN=""
|
||||
|
||||
# TURN server public IP address
|
||||
# required for a connection involving peers in
|
||||
# the same network as the server and external peers
|
||||
# usually matches the IP for the domain set in NETBIRD_TURN_DOMAIN
|
||||
# Required for peers behind NAT to connect
|
||||
NETBIRD_TURN_EXTERNAL_IP=""
|
||||
|
||||
# -------------------------------------------
|
||||
# OIDC
|
||||
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration
|
||||
# Optional: Database Configuration
|
||||
# -------------------------------------------
|
||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
||||
# The default setting is to transmit the audience to the IDP during authorization. However,
|
||||
# if your IDP does not have this capability, you can turn this off by setting it to false.
|
||||
#NETBIRD_DASH_AUTH_USE_AUDIENCE=false
|
||||
NETBIRD_AUTH_AUDIENCE=""
|
||||
# e.g. netbird-client
|
||||
NETBIRD_AUTH_CLIENT_ID=""
|
||||
# indicates the scopes that will be requested to the IDP
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=""
|
||||
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
|
||||
# NETBIRD_AUTH_CLIENT_SECRET=""
|
||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
||||
# indicates whether to use Auth0 or not: true or false
|
||||
NETBIRD_USE_AUTH0="false"
|
||||
# if your IDP provider doesn't support fragmented URIs, configure custom
|
||||
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
|
||||
# NETBIRD_AUTH_REDIRECT_URI="/peers"
|
||||
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
|
||||
# Updates the preference to use id tokens instead of access token on dashboard
|
||||
# Okta and Gitlab IDPs can benefit from this
|
||||
# NETBIRD_TOKEN_SOURCE="idToken"
|
||||
# Store engine: sqlite (default), postgres, or mysql
|
||||
NETBIRD_STORE_CONFIG_ENGINE=""
|
||||
|
||||
# For PostgreSQL:
|
||||
# NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<HOST> user=<USER> password=<PASS> dbname=<DB> port=5432"
|
||||
|
||||
# For MySQL:
|
||||
# NETBIRD_STORE_ENGINE_MYSQL_DSN="<user>:<pass>@tcp(127.0.0.1:3306)/<db>"
|
||||
|
||||
# -------------------------------------------
|
||||
# OIDC Device Authorization Flow
|
||||
# Optional: Extra Settings
|
||||
# -------------------------------------------
|
||||
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
|
||||
# Some IDPs requires different audience, scopes and to use id token for device authorization flow
|
||||
# you can customize here:
|
||||
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=false
|
||||
# -------------------------------------------
|
||||
# OIDC PKCE Authorization Flow
|
||||
# -------------------------------------------
|
||||
# Comma separated port numbers. if already in use, PKCE flow will choose an available port from the list as an alternative
|
||||
# eg. 53000,54000
|
||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS="53000"
|
||||
# -------------------------------------------
|
||||
# IDP Management
|
||||
# -------------------------------------------
|
||||
# eg. zitadel, auth0, azure, keycloak
|
||||
NETBIRD_MGMT_IDP="none"
|
||||
# Some IDPs requires different client id and client secret for management api
|
||||
NETBIRD_IDP_MGMT_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||
NETBIRD_IDP_MGMT_CLIENT_SECRET=""
|
||||
# Required when setting up with Keycloak "https://<YOUR_KEYCLOAK_HOST_AND_PORT>/admin/realms/netbird"
|
||||
# NETBIRD_IDP_MGMT_EXTRA_ADMIN_ENDPOINT=
|
||||
# With some IDPs may be needed enabling automatic refresh of signing keys on expire
|
||||
# NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=false
|
||||
# NETBIRD_IDP_MGMT_EXTRA_ variables. See https://docs.netbird.io/selfhosted/identity-providers for more information about your IDP of choice.
|
||||
# -------------------------------------------
|
||||
# Letsencrypt
|
||||
# -------------------------------------------
|
||||
# Disable letsencrypt
|
||||
# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead
|
||||
NETBIRD_DISABLE_LETSENCRYPT=false
|
||||
# e.g. hello@mydomain.com
|
||||
NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
# -------------------------------------------
|
||||
# Extra settings
|
||||
# -------------------------------------------
|
||||
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
|
||||
# Disable anonymous metrics (default: false)
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
|
||||
# DNS domain for peer resolution (default: netbird.selfhosted)
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
# Disable default all-to-all policy for new accounts
|
||||
|
||||
# Disable default all-to-all policy for new accounts (default: false)
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||
|
||||
# -------------------------------------------
|
||||
# Relay settings
|
||||
# Advanced: Zitadel Client IDs
|
||||
# -------------------------------------------
|
||||
# Relay server domain. e.g. relay.mydomain.com
|
||||
# if not specified it will assume NETBIRD_DOMAIN
|
||||
NETBIRD_RELAY_DOMAIN=""
|
||||
|
||||
# Relay server connection port. If none is supplied
|
||||
# it will default to 33080
|
||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
||||
NETBIRD_RELAY_PORT=""
|
||||
|
||||
# Management API connecting port. If none is supplied
|
||||
# it will default to 33073
|
||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
||||
NETBIRD_MGMT_API_PORT=""
|
||||
|
||||
# Signal service connecting port. If none is supplied
|
||||
# it will default to 10000
|
||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
||||
NETBIRD_SIGNAL_PORT=""
|
||||
# These are auto-generated by Zitadel on first boot
|
||||
# Only set these if migrating from an existing Zitadel setup
|
||||
# NETBIRD_AUTH_CLIENT_ID=""
|
||||
# NETBIRD_AUTH_CLIENT_ID_CLI=""
|
||||
# NETBIRD_IDP_MGMT_CLIENT_ID=""
|
||||
# NETBIRD_IDP_MGMT_CLIENT_SECRET=""
|
||||
|
||||
@@ -587,42 +587,40 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cache
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
userData, err := am.idpManager.GetAllAccounts(ctx)
|
||||
// Get all users from IdP
|
||||
idpUsers, err := am.idpManager.GetAllUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData))
|
||||
log.WithContext(ctx).Infof("%d users received from IdP management", len(idpUsers))
|
||||
|
||||
// If the Identity Provider does not support writing AppMetadata,
|
||||
// in cases like this, we expect it to return all users in an "unset" field.
|
||||
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
|
||||
// update their AppMetadata with the AccountID.
|
||||
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
|
||||
for _, user := range unsetData {
|
||||
accountID, err := am.Store.GetAccountByUser(ctx, user.ID)
|
||||
if err == nil {
|
||||
data := userData[accountID.Id]
|
||||
if data == nil {
|
||||
data = make([]*idp.UserData, 0, 1)
|
||||
}
|
||||
|
||||
user.AppMetadata.WTAccountID = accountID.Id
|
||||
|
||||
userData[accountID.Id] = append(data, user)
|
||||
}
|
||||
}
|
||||
// Create a map for quick lookup of IdP users by ID
|
||||
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
|
||||
for _, user := range idpUsers {
|
||||
idpUserMap[user.ID] = user
|
||||
}
|
||||
|
||||
// Group IdP users by their account ID from NetBird's database
|
||||
// NetBird DB is the source of truth for account membership
|
||||
accountUsers := make(map[string][]*idp.UserData)
|
||||
for _, idpUser := range idpUsers {
|
||||
account, err := am.Store.GetAccountByUser(ctx, idpUser.ID)
|
||||
if err != nil {
|
||||
// User exists in IdP but not in NetBird - skip
|
||||
continue
|
||||
}
|
||||
accountUsers[account.Id] = append(accountUsers[account.Id], idpUser)
|
||||
}
|
||||
delete(userData, idp.UnsetAccountID)
|
||||
|
||||
rcvdUsers := 0
|
||||
for accountID, users := range userData {
|
||||
for accountID, users := range accountUsers {
|
||||
rcvdUsers += len(users)
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheEntryExpiration())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
|
||||
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(accountUsers))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -742,10 +740,6 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
}
|
||||
return "", err
|
||||
@@ -757,35 +751,6 @@ func isNil(i idp.Manager) bool {
|
||||
return i == nil || reflect.ValueOf(i).IsNil()
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
accountID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) (any, []cacheStore.Option, error) {
|
||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||
accountIDString := fmt.Sprintf("%v", accountID)
|
||||
@@ -797,28 +762,32 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||
|
||||
// Get users belonging to this account from NetBird's database
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userData, err := am.idpManager.GetAccount(ctx, accountIDString)
|
||||
// Get all users from IdP (we don't have per-account queries anymore)
|
||||
idpUsers, err := am.idpManager.GetAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString)
|
||||
log.WithContext(ctx).Debugf("%d total users in IdP, matching against account %s", len(idpUsers), accountIDString)
|
||||
|
||||
dataMap := make(map[string]*idp.UserData, len(userData))
|
||||
for _, datum := range userData {
|
||||
dataMap[datum.ID] = datum
|
||||
// Create a map for quick lookup of IdP users by ID
|
||||
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
|
||||
for _, user := range idpUsers {
|
||||
idpUserMap[user.ID] = user
|
||||
}
|
||||
|
||||
// Match account users against IdP users
|
||||
matchedUserData := make([]*idp.UserData, 0)
|
||||
for _, user := range accountUsers {
|
||||
if user.IsServiceUser {
|
||||
continue
|
||||
}
|
||||
datum, ok := dataMap[user.Id]
|
||||
datum, ok := idpUserMap[user.Id]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id)
|
||||
continue
|
||||
@@ -972,7 +941,7 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
|
||||
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count
|
||||
func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
|
||||
userDataMap := make(map[string]*idp.UserData, len(data))
|
||||
for _, datum := range data {
|
||||
@@ -980,15 +949,10 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
|
||||
}
|
||||
|
||||
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
||||
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
||||
// as result of for loop knownUsersCount will have number of users are not presented in the cached
|
||||
knownUsersCount := len(accountUsers)
|
||||
for user, loggedInOnce := range accountUsers {
|
||||
if datum, ok := userDataMap[user]; ok {
|
||||
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
||||
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
|
||||
return false
|
||||
}
|
||||
for user := range accountUsers {
|
||||
if _, ok := userDataMap[user]; ok {
|
||||
knownUsersCount--
|
||||
continue
|
||||
}
|
||||
@@ -1078,12 +1042,6 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1110,11 +1068,6 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
@@ -1139,11 +1092,6 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if newUser.PendingApproval {
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
|
||||
} else {
|
||||
@@ -1155,34 +1103,30 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
// Get user from NetBird's database (source of truth for authorization data)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
|
||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||
// Check if user has a pending invite that needs to be redeemed
|
||||
if user.PendingInvite {
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||
// Our job is to just reload cache.
|
||||
go func() {
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}()
|
||||
|
||||
// Update user to mark invite as redeemed
|
||||
user.PendingInvite = false
|
||||
err = am.Store.SaveUser(ctx, user)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to redeem invite for user %s: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", userID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,959 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Auth0Manager auth0 manager client instance
|
||||
type Auth0Manager struct {
|
||||
authIssuer string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// Auth0ClientConfig auth0 manager client configurations
|
||||
type Auth0ClientConfig struct {
|
||||
Audience string
|
||||
AuthIssuer string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// auth0JWTRequest payload struct to request a JWT Token
|
||||
type auth0JWTRequest struct {
|
||||
Audience string `json:"audience"`
|
||||
AuthIssuer string `json:"auth_issuer"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
GrantType string `json:"grant_type"`
|
||||
}
|
||||
|
||||
// Auth0Credentials auth0 authentication information
|
||||
type Auth0Credentials struct {
|
||||
clientConfig Auth0ClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// createUserRequest is a user create request
|
||||
type createUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
AppMeta AppMetadata `json:"app_metadata"`
|
||||
Connection string `json:"connection"`
|
||||
Password string `json:"password"`
|
||||
VerifyEmail bool `json:"verify_email"`
|
||||
}
|
||||
|
||||
// userExportJobRequest is a user export request struct
|
||||
type userExportJobRequest struct {
|
||||
Format string `json:"format"`
|
||||
Fields []map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
// userExportJobResponse is a user export response struct
|
||||
type userExportJobResponse struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Format string `json:"format"`
|
||||
Limit int `json:"limit"`
|
||||
Connection string `json:"connection"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userExportJobStatusResponse is a user export status response struct
|
||||
type userExportJobStatusResponse struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Format string `json:"format"`
|
||||
Limit int `json:"limit"`
|
||||
Location string `json:"location"`
|
||||
Connection string `json:"connection"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userVerificationJobRequest is a user verification request struct
|
||||
type userVerificationJobRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// auth0Profile represents an Auth0 user profile response
|
||||
type auth0Profile struct {
|
||||
AccountID string `json:"wt_account_id"`
|
||||
PendingInvite bool `json:"wt_pending_invite"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastLogin string `json:"last_login"`
|
||||
}
|
||||
|
||||
// Connections represents a single Auth0 connection
|
||||
// https://auth0.com/docs/api/management/v2/connections/get-connections
|
||||
type Connection struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
IsDomainConnection bool `json:"is_domain_connection"`
|
||||
Realms []string `json:"realms"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
Options ConnectionOptions `json:"options"`
|
||||
}
|
||||
|
||||
type ConnectionOptions struct {
|
||||
DomainAliases []string `json:"domain_aliases"`
|
||||
}
|
||||
|
||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.AuthIssuer == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, AuthIssuer is missing")
|
||||
}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.Audience == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &Auth0Manager{
|
||||
authIssuer: config.AuthIssuer,
|
||||
credentials: credentials,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from Auth0
|
||||
func (c *Auth0Credentials) jwtStillValid() bool {
|
||||
return !c.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(c.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token
|
||||
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
var res *http.Response
|
||||
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
|
||||
|
||||
p, err := c.helper.Marshal(auth0JWTRequest(c.clientConfig))
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
|
||||
req, err := http.NewRequest("POST", reqURL, payload)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
|
||||
|
||||
res, err = c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if c.appMetrics != nil {
|
||||
c.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return res, fmt.Errorf("unable to get token, statusCode %d", res.StatusCode)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = c.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = json.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Auth0 Management API
|
||||
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.appMetrics != nil {
|
||||
c.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// If jwtToken has an expires time and we have enough time to do a request return immediately
|
||||
if c.jwtStillValid() {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
res, err := c.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
jwtToken, err := c.parseRequestJWTResponse(res.Body)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
|
||||
c.jwtToken = jwtToken
|
||||
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
|
||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
q.Set("search_engine", "v3")
|
||||
q.Set("per_page", strconv.Itoa(perPage))
|
||||
q.Set("q", "app_metadata.wt_account_id:"+accountID)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), q, nil
|
||||
}
|
||||
|
||||
func requestByUserIDURL(authIssuer, userID string) string {
|
||||
return authIssuer + "/api/v2/users/" + userID
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile. Calls Auth0 API.
|
||||
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var list []*UserData
|
||||
|
||||
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
||||
// auth0 limitation of 1000 users via this endpoint
|
||||
resultsPerPage := 50
|
||||
for page := 0; page < 20; page++ {
|
||||
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
|
||||
}
|
||||
|
||||
var batch []UserData
|
||||
err = json.Unmarshal(body, &batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for user := range batch {
|
||||
list = append(list, &batch[user])
|
||||
}
|
||||
|
||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
||||
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from auth0 via ID
|
||||
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := requestByUserIDURL(am.authIssuer, userID)
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userData UserData
|
||||
err = json.Unmarshal(body, &userData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get UserData, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return &userData, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
||||
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + "/api/v2/users/" + userID
|
||||
|
||||
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
req, err := http.NewRequest("PATCH", reqURL, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCreateUserRequestPayload(email, name, accountID, invitedByEmail string) (string, error) {
|
||||
invite := true
|
||||
req := &createUserRequest{
|
||||
Email: email,
|
||||
Name: name,
|
||||
AppMeta: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
Connection: "Username-Password-Authentication",
|
||||
Password: GeneratePassword(8, 1, 1, 1),
|
||||
VerifyEmail: true,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func buildUserExportRequest() (string, error) {
|
||||
req := &userExportJobRequest{}
|
||||
fields := make([]map[string]string, 0)
|
||||
|
||||
for _, field := range []string{"created_at", "last_login", "user_id", "email", "name"} {
|
||||
fields = append(fields, map[string]string{"name": field})
|
||||
}
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_account_id",
|
||||
"export_as": "wt_account_id",
|
||||
})
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_pending_invite",
|
||||
"export_as": "wt_pending_invite",
|
||||
})
|
||||
|
||||
req.Format = "json"
|
||||
req.Fields = fields
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createRequest(
|
||||
ctx context.Context, method string, endpoint string, body io.Reader,
|
||||
) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + endpoint
|
||||
|
||||
req, err := http.NewRequest(method, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
payloadString, err := buildUserExportRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = jobResp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if jobResp.StatusCode != 200 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
|
||||
}
|
||||
|
||||
var exportJobResp userExportJobResponse
|
||||
|
||||
body, err := io.ReadAll(jobResp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &exportJobResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if exportJobResp.ID == "" {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
|
||||
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if done {
|
||||
return am.downloadProfileExport(ctx, downloadLink)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
||||
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
|
||||
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
userResp := []*UserData{}
|
||||
|
||||
err = am.helper.Unmarshal(body, &userResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userResp, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
||||
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
|
||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var createResp UserData
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &createResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if createResp.ID == "" {
|
||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
userVerificationReq := userVerificationJobRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
payload, err := am.helper.Marshal(userVerificationReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser from Auth0
|
||||
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
|
||||
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("execute delete request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("close delete request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 204 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnections returns detailed list of all connections filtered by given params.
|
||||
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
|
||||
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
|
||||
var connections []Connection
|
||||
|
||||
q := make(url.Values)
|
||||
q.Set("strategy", strings.Join(strategy, ","))
|
||||
|
||||
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
if err != nil {
|
||||
return connections, err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return connections, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 200 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return connections, fmt.Errorf("unable to get connections, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &connections)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
return connections, err
|
||||
}
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
retry := time.NewTicker(10 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Debugf("Export job status stopped...\n")
|
||||
return false, "", ctx.Err()
|
||||
case <-retry.C:
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
|
||||
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
var status userExportJobStatusResponse
|
||||
err = am.helper.Unmarshal(body, &status)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
|
||||
|
||||
if status.Status != "completed" {
|
||||
continue
|
||||
}
|
||||
|
||||
return true, status.Location, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// downloadProfileExport downloads user profiles from auth0 batch job
|
||||
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(ctx, am.httpClient, location, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader := bytes.NewReader(body)
|
||||
|
||||
gzipReader, err := gzip.NewReader(bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(gzipReader)
|
||||
|
||||
res := make(map[string][]*UserData)
|
||||
for decoder.More() {
|
||||
profile := auth0Profile{}
|
||||
err = decoder.Decode(&profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.AccountID != "" {
|
||||
if _, ok := res[profile.AccountID]; !ok {
|
||||
res[profile.AccountID] = []*UserData{}
|
||||
}
|
||||
res[profile.AccountID] = append(res[profile.AccountID],
|
||||
&UserData{
|
||||
ID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
Email: profile.Email,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: profile.AccountID,
|
||||
WTPendingInvite: &profile.PendingInvite,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
req.Header.Add("authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
@@ -1,473 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
code int
|
||||
resBody string
|
||||
reqBody string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: io.NopCloser(strings.NewReader(c.resBody)),
|
||||
}, c.err
|
||||
}
|
||||
|
||||
type mockJsonParser struct {
|
||||
jsonParser JsonParser
|
||||
marshalErrorString string
|
||||
unmarshalErrorString string
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
|
||||
if m.marshalErrorString != "" {
|
||||
return nil, errors.New(m.marshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Marshal(v)
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
|
||||
if m.unmarshalErrorString != "" {
|
||||
return errors.New(m.unmarshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
type mockAuth0Credentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func newTestJWT(t *testing.T, expInt int) string {
|
||||
t.Helper()
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(time.Duration(expInt) * time.Second).Unix(),
|
||||
})
|
||||
var hmacSampleSecret []byte
|
||||
tokenString, err := token.SignedString(hmacSampleSecret)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func TestAuth0_RequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
res, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = json.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputResBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputResBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_JwtStillValid(t *testing.T) {
|
||||
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_Authenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
// expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 15
|
||||
token := newTestJWT(t, exp)
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAuth0Credentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAuth0Credentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
var creds ManagerCredentials
|
||||
if testCase.managerCreds != nil {
|
||||
creds = testCase.managerCreds
|
||||
} else {
|
||||
creds = &Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
}
|
||||
|
||||
manager := Auth0Manager{
|
||||
httpClient: &jwtReqClient,
|
||||
credentials: creds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth0Manager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig Auth0ClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := Auth0ClientConfig{
|
||||
AuthIssuer: "https://abc-auth0.eu.auth0.com",
|
||||
Audience: "https://abc-auth0.eu.auth0.com/api/v2/",
|
||||
ClientID: "abcdefg",
|
||||
ClientSecret: "supersecret",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Scenario With Config",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "shouldn't return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// AuthentikManager authentik manager client instance.
|
||||
type AuthentikManager struct {
|
||||
apiClient *api.APIClient
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// AuthentikClientConfig authentik manager client configurations.
|
||||
type AuthentikClientConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Username string
|
||||
Password string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// AuthentikCredentials authentik authentication information.
|
||||
type AuthentikCredentials struct {
|
||||
clientConfig AuthentikClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
||||
func NewAuthentikManager(config AuthentikClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.Username == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Username is missing")
|
||||
}
|
||||
|
||||
if config.Password == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Password is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.Issuer == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
// authentik client configuration
|
||||
issuerURL, err := url.Parse(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authentikConfig := api.NewConfiguration()
|
||||
authentikConfig.HTTPClient = httpClient
|
||||
authentikConfig.Host = issuerURL.Host
|
||||
authentikConfig.Scheme = issuerURL.Scheme
|
||||
|
||||
credentials := &AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &AuthentikManager{
|
||||
apiClient: api.NewAPIClient(authentikConfig),
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from authentik.
|
||||
func (ac *AuthentikCredentials) jwtStillValid() bool {
|
||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("username", ac.clientConfig.Username)
|
||||
data.Set("password", ac.clientConfig.Password)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
data.Set("scope", "goauthentik.io/api")
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get authentik token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the authentik management API.
|
||||
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if ac.jwtStillValid() {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
|
||||
ac.jwtToken = jwtToken
|
||||
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from authentik via ID.
|
||||
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, resp, err := am.apiClient.CoreApi.CoreUsersRetrieve(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData := parseAuthentikUser(*user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Authentik account.
|
||||
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
page := int32(1)
|
||||
for {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for _, user := range userList.Results {
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
page = int32(userList.GetPagination().Next)
|
||||
if userList.GetPagination().Next == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
||||
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Email(email).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Authentik
|
||||
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := map[string]api.APIKey{
|
||||
"authentik": {
|
||||
Key: jwtToken.AccessToken,
|
||||
Prefix: jwtToken.TokenType,
|
||||
},
|
||||
}
|
||||
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
|
||||
}
|
||||
|
||||
func parseAuthentikUser(user api.User) *UserData {
|
||||
return &UserData{
|
||||
Email: *user.Email,
|
||||
Name: user.Name,
|
||||
ID: strconv.FormatInt(int64(user.Pk), 10),
|
||||
}
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewAuthentikManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig AuthentikClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := AuthentikClientConfig{
|
||||
ClientID: "client_id",
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
TokenEndpoint: "https://localhost:8080/application/o/token/",
|
||||
Issuer: "https://localhost:8080/application/o/netbird/",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.Username = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing Username Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.Password = ""
|
||||
|
||||
testCase4 := test{
|
||||
name: "Missing Password Configuration",
|
||||
inputConfig: testCase4Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Missing GrantType Configuration",
|
||||
inputConfig: testCase5Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase6Config := defaultTestConfig
|
||||
testCase6Config.Issuer = ""
|
||||
|
||||
testCase6 := test{
|
||||
name: "Missing Issuer Configuration",
|
||||
inputConfig: testCase6Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikRequestJWTToken(t *testing.T) {
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,459 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const profileFields = "id,displayName,mail,userPrincipalName"
|
||||
|
||||
// AzureManager azure manager client instance.
|
||||
type AzureManager struct {
|
||||
ClientID string
|
||||
ObjectID string
|
||||
GraphAPIEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// AzureClientConfig azure manager client configurations.
|
||||
type AzureClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ObjectID string
|
||||
GraphAPIEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// AzureCredentials azure authentication information.
|
||||
type AzureCredentials struct {
|
||||
clientConfig AzureClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// azureProfile represents an azure user profile.
|
||||
type azureProfile map[string]any
|
||||
|
||||
// NewAzureManager creates a new instance of the AzureManager.
|
||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GraphAPIEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.ObjectID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &AzureCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &AzureManager{
|
||||
ObjectID: config.ObjectID,
|
||||
ClientID: config.ClientID,
|
||||
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
||||
func (ac *AzureCredentials) jwtStillValid() bool {
|
||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get base url and add "/.default" as scope
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
scopeURL := baseURL + "/.default"
|
||||
data.Set("scope", scopeURL)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the azure Management API.
|
||||
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if ac.jwtStillValid() {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
|
||||
ac.jwtToken = jwtToken
|
||||
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in azure AD Idp.
|
||||
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get(ctx, "users/"+userID, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get(ctx, "users/"+email, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, profile.userData())
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure.
|
||||
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debugf("delete idp user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Azure AD account.
|
||||
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
q.Add("$top", "500")
|
||||
|
||||
for nextLink := "users"; nextLink != ""; {
|
||||
body, err := am.get(ctx, nextLink, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles struct {
|
||||
Value []azureProfile
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
err = am.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, profile := range profiles.Value {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
nextLink = profiles.NextLink
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reqURL string
|
||||
if strings.HasPrefix(resource, "https") {
|
||||
// Already an absolute URL for paging
|
||||
reqURL = resource
|
||||
} else {
|
||||
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (ap azureProfile) userData() *UserData {
|
||||
id, ok := ap["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
|
||||
email, ok := ap["userPrincipalName"].(string)
|
||||
if !ok {
|
||||
email = ""
|
||||
}
|
||||
|
||||
name, ok := ap["displayName"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: id,
|
||||
}
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAzureJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := AzureClientConfig{}
|
||||
|
||||
creds := AzureCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AzureClientConfig{}
|
||||
|
||||
creds := AzureCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile azureProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "Good Request",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test1",
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test1@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test1@test.com",
|
||||
Name: "John Doe",
|
||||
ID: "test1",
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Missing User ID",
|
||||
invite: true,
|
||||
inputProfile: azureProfile{
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test2@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test2@test.com",
|
||||
Name: "John Doe",
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase3 := azureProfileTest{
|
||||
name: "Missing User Name",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test3",
|
||||
"userPrincipalName": "test3@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test3",
|
||||
Email: "test3@test.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// GoogleWorkspaceManager Google Workspace manager client instance.
|
||||
type GoogleWorkspaceManager struct {
|
||||
usersService *admin.UsersService
|
||||
CustomerID string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
|
||||
type GoogleWorkspaceClientConfig struct {
|
||||
ServiceAccountKey string
|
||||
CustomerID string
|
||||
}
|
||||
|
||||
// GoogleWorkspaceCredentials Google Workspace authentication information.
|
||||
type GoogleWorkspaceCredentials struct {
|
||||
clientConfig GoogleWorkspaceClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
||||
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.CustomerID == "" {
|
||||
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
|
||||
}
|
||||
|
||||
credentials := &GoogleWorkspaceCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
// Create a new Admin SDK Directory service client
|
||||
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := admin.NewService(context.Background(),
|
||||
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
|
||||
option.WithCredentials(adminCredentials),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GoogleWorkspaceManager{
|
||||
usersService: service.Users,
|
||||
CustomerID: config.CustomerID,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from Google Workspace via ID.
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, err := gm.usersService.Get(userID).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
userData := parseGoogleWorkspaceUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
|
||||
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
pageToken := ""
|
||||
for {
|
||||
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
|
||||
if pageToken != "" {
|
||||
call.PageToken(pageToken)
|
||||
}
|
||||
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range resp.Users {
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
}
|
||||
|
||||
pageToken = resp.NextPageToken
|
||||
if pageToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, err := gm.usersService.Get(email).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from GoogleWorkspace.
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
|
||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
// It returns the retrieved credentials or an error if unsuccessful.
|
||||
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
||||
}
|
||||
|
||||
creds, err := google.CredentialsFromJSON(
|
||||
context.Background(),
|
||||
decodeKey,
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err == nil {
|
||||
// No need to fallback to the default Google credentials path
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.WithContext(ctx).Debug("falling back to default google credentials location")
|
||||
|
||||
creds, err = google.FindDefaultCredentials(
|
||||
context.Background(),
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// parseGoogleWorkspaceUser parse google user to UserData.
|
||||
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
|
||||
return &UserData{
|
||||
ID: user.Id,
|
||||
Email: user.PrimaryEmail,
|
||||
Name: user.Name.FullName,
|
||||
}
|
||||
}
|
||||
@@ -11,24 +11,25 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnsetAccountID is a special key to map users without an account ID
|
||||
UnsetAccountID = "unset"
|
||||
)
|
||||
|
||||
// Manager idp manager interface
|
||||
// Note: NetBird is the single source of truth for authorization data (roles, account membership, invite status).
|
||||
// The IdP only stores identity information (email, name, credentials).
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
// CreateUser creates a new user in the IdP. Returns basic user data (ID, email, name).
|
||||
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||
// GetUserDataByID retrieves user identity data from the IdP by user ID.
|
||||
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||
// GetUserByEmail searches for users by email address.
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
// GetAllUsers returns all users from the IdP for cache warming.
|
||||
GetAllUsers(ctx context.Context) ([]*UserData, error)
|
||||
// InviteUserByID resends an invitation to a user who hasn't completed signup.
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
// DeleteUser removes a user from the IdP.
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
// ClientConfig defines common client configuration for the IdP manager
|
||||
type ClientConfig struct {
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
@@ -42,13 +43,10 @@ type ExtraConfig map[string]string
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
type Config struct {
|
||||
ManagerType string
|
||||
ClientConfig *ClientConfig
|
||||
ExtraConfig ExtraConfig
|
||||
Auth0ClientCredentials *Auth0ClientConfig
|
||||
AzureClientCredentials *AzureClientConfig
|
||||
KeycloakClientCredentials *KeycloakClientConfig
|
||||
ZitadelClientCredentials *ZitadelClientConfig
|
||||
ManagerType string
|
||||
ClientConfig *ClientConfig
|
||||
ExtraConfig ExtraConfig
|
||||
ZitadelClientCredentials *ZitadelClientConfig
|
||||
}
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
@@ -67,11 +65,12 @@ type ManagerHelper interface {
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
}
|
||||
|
||||
// UserData represents identity information from the IdP.
|
||||
// Note: Authorization data (account membership, roles, invite status) is stored in NetBird's DB.
|
||||
type UserData struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
AppMetadata AppMetadata `json:"app_metadata"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
}
|
||||
|
||||
func (u *UserData) MarshalBinary() (data []byte, err error) {
|
||||
@@ -91,15 +90,6 @@ func (u *UserData) Unmarshal(data []byte) (err error) {
|
||||
return json.Unmarshal(data, &u)
|
||||
}
|
||||
|
||||
// AppMetadata user app metadata to associate with a profile
|
||||
type AppMetadata struct {
|
||||
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
|
||||
// maps to wt_account_id when json.marshal
|
||||
WTAccountID string `json:"wt_account_id,omitempty"`
|
||||
WTPendingInvite *bool `json:"wt_pending_invite,omitempty"`
|
||||
WTInvitedBy string `json:"wt_invited_by_email,omitempty"`
|
||||
}
|
||||
|
||||
// JWTToken a JWT object that holds information of a token
|
||||
type JWTToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -109,7 +99,8 @@ type JWTToken struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// NewManager returns a new idp manager based on the configuration that it receives
|
||||
// NewManager returns a new idp manager based on the configuration that it receives.
|
||||
// Only Zitadel is supported as the IdP manager.
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
if config.ClientConfig != nil {
|
||||
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
||||
@@ -118,46 +109,6 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
switch strings.ToLower(config.ManagerType) {
|
||||
case "none", "":
|
||||
return nil, nil //nolint:nilnil
|
||||
case "auth0":
|
||||
auth0ClientConfig := config.Auth0ClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
auth0ClientConfig = &Auth0ClientConfig{
|
||||
Audience: config.ExtraConfig["Audience"],
|
||||
AuthIssuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
}
|
||||
}
|
||||
|
||||
return NewAuth0Manager(*auth0ClientConfig, appMetrics)
|
||||
case "azure":
|
||||
azureClientConfig := config.AzureClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
azureClientConfig = &AzureClientConfig{
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
ObjectID: config.ExtraConfig["ObjectId"],
|
||||
GraphAPIEndpoint: config.ExtraConfig["GraphApiEndpoint"],
|
||||
}
|
||||
}
|
||||
|
||||
return NewAzureManager(*azureClientConfig, appMetrics)
|
||||
case "keycloak":
|
||||
keycloakClientConfig := config.KeycloakClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
keycloakClientConfig = &KeycloakClientConfig{
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
AdminEndpoint: config.ExtraConfig["AdminEndpoint"],
|
||||
}
|
||||
}
|
||||
|
||||
return NewKeycloakManager(*keycloakClientConfig, appMetrics)
|
||||
case "zitadel":
|
||||
zitadelClientConfig := config.ZitadelClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
@@ -172,42 +123,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
}
|
||||
|
||||
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||
case "authentik":
|
||||
authentikConfig := AuthentikClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
Username: config.ExtraConfig["Username"],
|
||||
Password: config.ExtraConfig["Password"],
|
||||
}
|
||||
return NewAuthentikManager(authentikConfig, appMetrics)
|
||||
case "okta":
|
||||
oktaClientConfig := OktaClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
||||
case "google":
|
||||
googleClientConfig := GoogleWorkspaceClientConfig{
|
||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||
CustomerID: config.ExtraConfig["CustomerId"],
|
||||
}
|
||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
||||
case "jumpcloud":
|
||||
jumpcloudConfig := JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
||||
case "pocketid":
|
||||
pocketidConfig := PocketIdClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||
}
|
||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
accept = "application/json"
|
||||
)
|
||||
|
||||
// JumpCloudManager JumpCloud manager client instance.
|
||||
type JumpCloudManager struct {
|
||||
client *v1.APIClient
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// JumpCloudClientConfig JumpCloud manager client configurations.
|
||||
type JumpCloudClientConfig struct {
|
||||
APIToken string
|
||||
}
|
||||
|
||||
// JumpCloudCredentials JumpCloud authentication information.
|
||||
type JumpCloudCredentials struct {
|
||||
clientConfig JumpCloudClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewJumpCloudManager creates a new instance of the JumpCloudManager.
|
||||
func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
|
||||
}
|
||||
|
||||
client := v1.NewAPIClient(v1.NewConfiguration())
|
||||
credentials := &JumpCloudCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &JumpCloudManager{
|
||||
client: client,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the JumpCloud user API.
|
||||
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
|
||||
Key: jm.apiToken,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range userList.Results {
|
||||
userData := parseJumpCloudUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
"fields": []string{"email"},
|
||||
},
|
||||
}
|
||||
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
usersData = append(usersData, parseJumpCloudUser(user))
|
||||
}
|
||||
|
||||
return usersData, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
|
||||
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
|
||||
names := []string{user.Firstname, user.Middlename, user.Lastname}
|
||||
return &UserData{
|
||||
Email: user.Email,
|
||||
Name: strings.Join(names, " "),
|
||||
ID: user.Id,
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewJumpCloudManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig JumpCloudClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := JumpCloudClientConfig{
|
||||
APIToken: "test123",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.APIToken = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing APIToken Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,439 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// KeycloakManager keycloak manager client instance.
|
||||
type KeycloakManager struct {
|
||||
adminEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// KeycloakClientConfig keycloak manager client configurations.
|
||||
type KeycloakClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AdminEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// KeycloakCredentials keycloak authentication information.
|
||||
type KeycloakCredentials struct {
|
||||
clientConfig KeycloakClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// keycloakUserAttributes holds additional user data fields.
|
||||
type keycloakUserAttributes map[string][]string
|
||||
|
||||
// keycloakProfile represents a keycloak user profile response.
|
||||
type keycloakProfile struct {
|
||||
ID string `json:"id"`
|
||||
CreatedTimestamp int64 `json:"createdTimestamp"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Attributes keycloakUserAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.AdminEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &KeycloakManager{
|
||||
adminEndpoint: config.AdminEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
|
||||
func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kc.clientConfig.ClientID)
|
||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", kc.clientConfig.GrantType)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = kc.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = kc.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the keycloak Management API.
|
||||
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
kc.mux.Lock()
|
||||
defer kc.mux.Unlock()
|
||||
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if kc.jwtStillValid() {
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := kc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
|
||||
kc.jwtToken = jwtToken
|
||||
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("email", email)
|
||||
q.Add("exact", "true")
|
||||
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile keycloakProfile
|
||||
err = km.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profile.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account profile.
|
||||
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloak by user ID.
|
||||
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
// In the docs, they specified 200, but in the endpoints, they return 204
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
|
||||
body, err := km.get(ctx, "users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(string(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &count, nil
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (kp keycloakProfile) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: kp.Email,
|
||||
Name: kp.Username,
|
||||
ID: kp.ID,
|
||||
}
|
||||
}
|
||||
@@ -1,310 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewKeycloakManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig KeycloakClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := KeycloakClientConfig{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
|
||||
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.ClientSecret = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing ClientSecret Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.TokenEndpoint = ""
|
||||
|
||||
testCase4 := test{
|
||||
name: "Missing TokenEndpoint Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Missing GrantType Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,52 +4,26 @@ import "context"
|
||||
|
||||
// MockIDP is a mock implementation of the IDP interface
|
||||
type MockIDP struct {
|
||||
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
|
||||
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
|
||||
if m.UpdateUserAppMetadataFunc != nil {
|
||||
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAccount is a mock implementation of the IDP interface GetAccount method
|
||||
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
if m.GetAccountFunc != nil {
|
||||
return m.GetAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
|
||||
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if m.GetAllAccountsFunc != nil {
|
||||
return m.GetAllAccountsFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
CreateUserFunc func(ctx context.Context, email, name string) (*UserData, error)
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
GetAllUsersFunc func(ctx context.Context) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||
if m.CreateUserFunc != nil {
|
||||
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
|
||||
return m.CreateUserFunc(ctx, email, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(ctx, userId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -62,6 +36,14 @@ func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllUsers is a mock implementation of the IDP interface GetAllUsers method
|
||||
func (m *MockIDP) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
if m.GetAllUsersFunc != nil {
|
||||
return m.GetAllUsersFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
||||
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
||||
if m.InviteUserByIDFunc != nil {
|
||||
|
||||
@@ -1,306 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// OktaManager okta manager client instance.
|
||||
type OktaManager struct {
|
||||
client *okta.Client
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// OktaClientConfig okta manager client configurations.
|
||||
type OktaClientConfig struct {
|
||||
APIToken string
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// OktaCredentials okta authentication information.
|
||||
type OktaCredentials struct {
|
||||
clientConfig OktaClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewOktaManager creates a new instance of the OktaManager.
|
||||
func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*OktaManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.Issuer = baseURL(config.Issuer)
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
if config.Issuer == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, Issuer is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
_, client, err := okta.NewClient(context.Background(),
|
||||
okta.WithOrgUrl(config.Issuer),
|
||||
okta.WithToken(config.APIToken),
|
||||
okta.WithHttpClientPtr(httpClient),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials := &OktaCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &OktaManager{
|
||||
client: client,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the okta user API.
|
||||
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
||||
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, userData)
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Okta account.
|
||||
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
|
||||
qp := query.NewQueryParams(query.WithLimit(200))
|
||||
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for resp.HasNextPage() {
|
||||
paginatedUsers := make([]*okta.User, 0)
|
||||
resp, err = resp.Next(context.Background(), &paginatedUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
userList = append(userList, paginatedUsers...)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0, len(userList))
|
||||
for _, user := range userList {
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Okta
|
||||
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
|
||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOktaUser parse okta user to UserData.
|
||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||
var oktaUser struct {
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("invalid okta user")
|
||||
}
|
||||
|
||||
if user.Profile != nil {
|
||||
helper := JsonParser{}
|
||||
buf, err := helper.Marshal(*user.Profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = helper.Unmarshal(buf, &oktaUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: oktaUser.Email,
|
||||
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
|
||||
ID: user.Id,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseOktaUser(t *testing.T) {
|
||||
type parseOktaUserTest struct {
|
||||
name string
|
||||
inputProfile *okta.User
|
||||
expectedUserData *UserData
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
}
|
||||
|
||||
parseOktaTestCase1 := parseOktaUserTest{
|
||||
name: "Good Request",
|
||||
inputProfile: &okta.User{
|
||||
Id: "123",
|
||||
Profile: &okta.UserProfile{
|
||||
"email": "test@example.com",
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
},
|
||||
},
|
||||
expectedUserData: &UserData{
|
||||
Email: "test@example.com",
|
||||
Name: "John Doe",
|
||||
ID: "123",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "456",
|
||||
},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
}
|
||||
|
||||
parseOktaTestCase2 := parseOktaUserTest{
|
||||
name: "Invalid okta user",
|
||||
inputProfile: nil,
|
||||
expectedUserData: nil,
|
||||
assertErrFunc: assert.Error,
|
||||
}
|
||||
|
||||
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
userData, err := parseOktaUser(testCase.inputProfile)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
|
||||
|
||||
if err == nil {
|
||||
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// userDataEqual helper function to compare UserData structs for equality.
|
||||
func userDataEqual(a, b *UserData) bool {
|
||||
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,384 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type PocketIdManager struct {
|
||||
managementEndpoint string
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
type pocketIdCustomClaimDto struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type pocketIdUserDto struct {
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
Disabled bool `json:"disabled"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
ID string `json:"id"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
LastName string `json:"lastName"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Locale string `json:"locale"`
|
||||
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdUserCreateDto struct {
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
LastName string `json:"lastName,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdPaginatedUserDto struct {
|
||||
Data []pocketIdUserDto `json:"data"`
|
||||
Pagination pocketIdPaginationDto `json:"pagination"`
|
||||
}
|
||||
|
||||
type pocketIdPaginationDto struct {
|
||||
CurrentPage int `json:"currentPage"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
}
|
||||
|
||||
func (p *pocketIdUserDto) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.DisplayName,
|
||||
ID: p.ID,
|
||||
AppMetadata: AppMetadata{},
|
||||
}
|
||||
}
|
||||
|
||||
type pocketIdUserGroupDto struct {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
ID string `json:"id"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ManagementEndpoint == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
credentials := &PocketIdCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &PocketIdManager{
|
||||
managementEndpoint: config.ManagementEndpoint,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
|
||||
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
|
||||
if !slices.Contains(MethodsWithBody, method) && body != "" {
|
||||
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
|
||||
if query != nil {
|
||||
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
|
||||
}
|
||||
var req *http.Request
|
||||
var err error
|
||||
if body != "" {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
|
||||
} else {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("X-API-KEY", p.apiToken)
|
||||
|
||||
if body != "" {
|
||||
req.Header.Add("content-type", "application/json")
|
||||
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// getAllUsersPaginated fetches all users from PocketID API using pagination
|
||||
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
|
||||
var allUsers []pocketIdUserDto
|
||||
currentPage := 1
|
||||
|
||||
for {
|
||||
params := url.Values{}
|
||||
// Copy existing search parameters
|
||||
for key, values := range searchParams {
|
||||
params[key] = values
|
||||
}
|
||||
|
||||
params.Set("pagination[limit]", "100")
|
||||
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
|
||||
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles pocketIdPaginatedUserDto
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allUsers = append(allUsers, profiles.Data...)
|
||||
|
||||
// Check if we've reached the last page
|
||||
if currentPage >= profiles.Pagination.TotalPages {
|
||||
break
|
||||
}
|
||||
|
||||
currentPage++
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var user pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := user.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountId
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
firstLast := strings.Split(name, " ")
|
||||
|
||||
createUser := pocketIdUserCreateDto{
|
||||
Disabled: false,
|
||||
DisplayName: name,
|
||||
Email: email,
|
||||
FirstName: firstLast[0],
|
||||
LastName: firstLast[1],
|
||||
Username: firstLast[0] + "." + firstLast[1],
|
||||
}
|
||||
payload, err := p.helper.Marshal(createUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var newUser pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &newUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
params := url.Values{
|
||||
// This value a
|
||||
"search": []string{email},
|
||||
}
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profiles struct{ data []pocketIdUserDto }
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.data {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Manager = (*PocketIdManager)(nil)
|
||||
|
||||
type PocketIdClientConfig struct {
|
||||
APIToken string
|
||||
ManagementEndpoint string
|
||||
}
|
||||
|
||||
type PocketIdCredentials struct {
|
||||
clientConfig PocketIdClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
|
||||
|
||||
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewPocketIdManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig PocketIdClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := PocketIdClientConfig{
|
||||
APIToken: "api_token",
|
||||
ManagementEndpoint: "http://localhost",
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
},
|
||||
{
|
||||
name: "Missing ManagementEndpoint",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: defaultTestConfig.APIToken,
|
||||
ManagementEndpoint: "",
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
{
|
||||
name: "Missing APIToken",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: "",
|
||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
md := AppMetadata{WTAccountID: "acc1"}
|
||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "u1", got.ID)
|
||||
assert.Equal(t, "user1@example.com", got.Email)
|
||||
assert.Equal(t, "User One", got.Name)
|
||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
||||
// Single page response with two users
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
assert.Equal(t, "u1", users[0].ID)
|
||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "u2", users[1].ID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accounts[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestPocketID_CreateUser(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newid", ud.ID)
|
||||
assert.Equal(t, "new@example.com", ud.Email)
|
||||
assert.Equal(t, "New User", ud.Name)
|
||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||
}
|
||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||
}
|
||||
|
||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
||||
// Same mock for both calls; returns OK with empty JSON
|
||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.DeleteUser(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
47
management/server/idp/test_util_test.go
Normal file
47
management/server/idp/test_util_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockHTTPClient is a mock implementation of ManagerHTTPClient for testing
|
||||
type mockHTTPClient struct {
|
||||
code int
|
||||
resBody string
|
||||
reqBody string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
if req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: io.NopCloser(bytes.NewReader([]byte(c.resBody))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newTestJWT creates a test JWT token with the given expiration time in seconds
|
||||
func newTestJWT(t *testing.T, expiresIn int) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||
exp := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature"))
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
@@ -319,14 +319,19 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
// Note: Authorization data (account membership, invite status) is stored in NetBird's DB, not in the IdP.
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||
firstLast := strings.SplitN(name, " ", 2)
|
||||
lastName := firstLast[0]
|
||||
if len(firstLast) > 1 {
|
||||
lastName = firstLast[1]
|
||||
}
|
||||
|
||||
var addUser = map[string]any{
|
||||
"userName": email,
|
||||
"profile": map[string]string{
|
||||
"firstName": firstLast[0],
|
||||
"lastName": firstLast[0],
|
||||
"lastName": lastName,
|
||||
"displayName": name,
|
||||
},
|
||||
"email": map[string]any{
|
||||
@@ -357,18 +362,11 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.UserId,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@@ -413,7 +411,7 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from zitadel via ID.
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string) (*UserData, error) {
|
||||
body, err := zm.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -429,43 +427,12 @@ func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, ap
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := profile.User.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
return profile.User.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
var profiles struct{ Result []zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// GetAllUsers returns all users from the IdP.
|
||||
// Used for cache warming - NetBird matches these against its own user database.
|
||||
func (zm *ZitadelManager) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -481,19 +448,12 @@ func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*Use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
users := make([]*UserData, 0, len(profiles.Result))
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
return users, nil
|
||||
}
|
||||
|
||||
type inviteUserRequest struct {
|
||||
|
||||
@@ -288,16 +288,14 @@ func TestZitadelAuthenticate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestZitadelProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
type zitadelProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile zitadelProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "User Request",
|
||||
invite: false,
|
||||
zitadelProfileTestCase1 := zitadelProfileTest{
|
||||
name: "User Request",
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test1",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
@@ -322,15 +320,11 @@ func TestZitadelProfile(t *testing.T) {
|
||||
ID: "test1",
|
||||
Name: "ZITADEL Admin",
|
||||
Email: "test1@mail.com",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Service User Request",
|
||||
invite: true,
|
||||
zitadelProfileTestCase2 := zitadelProfileTest{
|
||||
name: "Service User Request",
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test2",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
@@ -345,15 +339,11 @@ func TestZitadelProfile(t *testing.T) {
|
||||
ID: "test2",
|
||||
Name: "machine",
|
||||
Email: "machine",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
||||
for _, testCase := range []zitadelProfileTest{zitadelProfileTestCase1, zitadelProfileTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -487,7 +486,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID)
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
|
||||
9974
management/server/types/testdata/networkmap_golden.json
vendored
Normal file
9974
management/server/types/testdata/networkmap_golden.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9974
management/server/types/testdata/networkmap_golden_new.json
vendored
Normal file
9974
management/server/types/testdata/networkmap_golden_new.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_new_with_deleted_router.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_new_with_deleted_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded_router.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_new_with_onpeerdeleted.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_new_with_onpeerdeleted.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_with_deleted_peer.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_with_deleted_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_with_deleted_router_peer.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_with_deleted_router_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_with_new_peer.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_with_new_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_with_new_router.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_with_new_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,8 @@ type User struct {
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// PendingInvite indicates whether the user has accepted their invite and logged in
|
||||
PendingInvite bool
|
||||
// PendingApproval indicates whether the user requires approval before being activated
|
||||
PendingApproval bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
@@ -162,7 +164,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
userStatus := UserStatusActive
|
||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||
if u.PendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
@@ -199,6 +201,7 @@ func (u *User) Copy() *User {
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
PendingInvite: u.PendingInvite,
|
||||
PendingApproval: u.PendingApproval,
|
||||
LastLogin: u.LastLogin,
|
||||
CreatedAt: u.CreatedAt,
|
||||
|
||||
@@ -117,6 +117,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
Issued: invite.Issued,
|
||||
IntegrationReference: invite.IntegrationReference,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
PendingInvite: true, // User hasn't accepted invite yet
|
||||
}
|
||||
|
||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||
@@ -169,7 +170,7 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
|
||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email)
|
||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
@@ -285,8 +286,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// check if the user is already registered with this ID
|
||||
user, err := am.lookupUserInCache(ctx, targetUserID, accountID)
|
||||
// Get user from NetBird's database (source of truth for authorization data)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -295,18 +296,17 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
return status.Errorf(status.NotFound, "user account %s doesn't exist", targetUserID)
|
||||
}
|
||||
|
||||
// check if user account is already invited and account is not activated
|
||||
pendingInvite := user.AppMetadata.WTPendingInvite
|
||||
if pendingInvite == nil || !*pendingInvite {
|
||||
// Check if user is still pending invite (hasn't activated their account)
|
||||
if !user.PendingInvite {
|
||||
return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account")
|
||||
}
|
||||
|
||||
err = am.idpManager.InviteUserByID(ctx, user.ID)
|
||||
err = am.idpManager.InviteUserByID(ctx, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, user.ID, accountID, activity.UserInvited, nil)
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserInvited, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1011,17 +1011,15 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error {
|
||||
if am.userDeleteFromIDPEnabled {
|
||||
log.WithContext(ctx).Debugf("user %s deleted from IdP", targetUserID)
|
||||
log.WithContext(ctx).Debugf("deleting user %s from IdP", targetUserID)
|
||||
err := am.idpManager.DeleteUser(ctx, targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
|
||||
}
|
||||
} else {
|
||||
err := am.idpManager.UpdateUserAppMetadata(ctx, targetUserID, idp.AppMetadata{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
|
||||
}
|
||||
}
|
||||
// Note: If userDeleteFromIDPEnabled is false, the user remains in IdP but is removed from NetBird.
|
||||
// This allows the user to re-authenticate if they're re-invited later.
|
||||
|
||||
err := am.removeUserFromCache(ctx, accountID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err)
|
||||
@@ -1095,7 +1093,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
if !isNil(am.idpManager) {
|
||||
// Delete if the user already exists in the IdP. Necessary in cases where a user account
|
||||
// was created where a user account was provisioned but the user did not sign in
|
||||
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID, idp.AppMetadata{WTAccountID: accountID})
|
||||
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID)
|
||||
if err == nil {
|
||||
err = am.deleteUserFromIDP(ctx, targetUserInfo.ID, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -563,7 +563,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
idpMock := idp.MockIDP{
|
||||
CreateUserFunc: func(_ context.Context, email, name, accountID, invitedByEmail string) (*idp.UserData, error) {
|
||||
CreateUserFunc: func(_ context.Context, email, name string) (*idp.UserData, error) {
|
||||
newData := &idp.UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
@@ -574,7 +574,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
|
||||
return newData, nil
|
||||
},
|
||||
GetAccountFunc: func(_ context.Context, accountId string) ([]*idp.UserData, error) {
|
||||
GetAllUsersFunc: func(_ context.Context) ([]*idp.UserData, error) {
|
||||
return mockData, nil
|
||||
},
|
||||
}
|
||||
@@ -1068,7 +1068,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
|
||||
idpManager: &idp.MockIDP{}, // empty manager
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user