mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-29 13:46:41 +00:00
Compare commits
2 Commits
fix-darwin
...
feat/integ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9291e3134b | ||
|
|
eb578146e4 |
352
docs/plans/self-service-auth-plan.mdx
Normal file
352
docs/plans/self-service-auth-plan.mdx
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
# Zitadel Integration Plan
|
||||||
|
|
||||||
|
This plan is split into stages. Each stage is self-contained and can be implemented and tested independently.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 1: Infrastructure Setup [COMPLETED]
|
||||||
|
|
||||||
|
**Goal**: Add Zitadel to docker-compose with first-boot initialization (single container, SQLite).
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
- **KISS**: Single Zitadel container with SQLite (no separate PostgreSQL or init containers)
|
||||||
|
- **DRY**: Leverage Zitadel's built-in first-instance configuration via environment variables
|
||||||
|
- **Clean**: Credentials generated in configure.sh and printed directly
|
||||||
|
|
||||||
|
### Files Created
|
||||||
|
|
||||||
|
#### `infrastructure_files/Caddyfile.tmpl`
|
||||||
|
Reverse proxy configuration routing:
|
||||||
|
- `/api/*`, `/management.*` → Management server
|
||||||
|
- `/relay*` → Relay
|
||||||
|
- `/signalexchange.*`, `/signal*` → Signal
|
||||||
|
- `/oauth/*`, `/oidc/*`, `/.well-known/*`, `/ui/*` → Zitadel
|
||||||
|
- `/*` → Dashboard
|
||||||
|
|
||||||
|
### Files Modified
|
||||||
|
|
||||||
|
#### `infrastructure_files/docker-compose.yml.tmpl`
|
||||||
|
Added single Zitadel service using SQLite:
|
||||||
|
```yaml
|
||||||
|
zitadel:
|
||||||
|
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
|
||||||
|
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||||
|
environment:
|
||||||
|
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
|
||||||
|
# ... service account config via env vars
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `infrastructure_files/management.json.tmpl`
|
||||||
|
Updated IdpManagerConfig to default to Zitadel.
|
||||||
|
|
||||||
|
#### `infrastructure_files/base.setup.env`
|
||||||
|
Added Zitadel-specific environment variables:
|
||||||
|
- `ZITADEL_TAG` (default: v4.7.6)
|
||||||
|
- `ZITADEL_MASTERKEY` (auto-generated)
|
||||||
|
- `ZITADEL_ADMIN_USERNAME`, `ZITADEL_ADMIN_PASSWORD` (auto-generated)
|
||||||
|
- `ZITADEL_EXTERNALSECURE`, `ZITADEL_EXTERNALPORT`, `ZITADEL_TLS_MODE`
|
||||||
|
|
||||||
|
#### `infrastructure_files/configure.sh`
|
||||||
|
- Generates Zitadel masterkey if not set
|
||||||
|
- Generates admin credentials if not set
|
||||||
|
- Prints credentials to stdout during configuration
|
||||||
|
|
||||||
|
### Deliverable
|
||||||
|
Running `./configure.sh && cd artifacts && docker compose up -d` starts full stack with Zitadel. Admin credentials printed during configuration.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 2: Simplify IdP Integration
|
||||||
|
|
||||||
|
**Goal**: Make NetBird the single source of truth for user authorization data. Zitadel handles authentication only.
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
- **Single source of truth**: NetBird DB stores all authorization data (roles, account membership, invite status)
|
||||||
|
- **Clean separation**: Zitadel stores identity only (email, name, password)
|
||||||
|
- **No duplicate data**: Remove `wt_account_id` and `wt_pending_invite` from IdP metadata
|
||||||
|
|
||||||
|
### Files to Modify
|
||||||
|
|
||||||
|
#### `management/server/types/user.go`
|
||||||
|
Add `PendingInvite` field to User struct:
|
||||||
|
```go
|
||||||
|
type User struct {
|
||||||
|
// ... existing fields ...
|
||||||
|
PendingInvite bool // NEW: tracks if user has accepted invite
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Update `ToUserInfo()` to use local field instead of IdP metadata:
|
||||||
|
```go
|
||||||
|
// Before
|
||||||
|
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||||
|
userStatus = UserStatusInvited
|
||||||
|
}
|
||||||
|
|
||||||
|
// After
|
||||||
|
if u.PendingInvite {
|
||||||
|
userStatus = UserStatusInvited
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `management/server/idp/idp.go`
|
||||||
|
Simplify `Manager` interface:
|
||||||
|
```go
|
||||||
|
type Manager interface {
|
||||||
|
// Simplified - no more accountID or invitedByEmail params
|
||||||
|
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||||
|
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||||
|
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||||
|
InviteUserByID(ctx context.Context, userID string) error
|
||||||
|
DeleteUser(ctx context.Context, userID string) error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Remove:
|
||||||
|
- `UpdateUserAppMetadata()` - no longer needed
|
||||||
|
- `GetAccount()` - account membership now in NetBird DB
|
||||||
|
- `GetAllAccounts()` - account membership now in NetBird DB
|
||||||
|
- `AppMetadata` struct fields (`WTAccountID`, `WTPendingInvite`, `WTInvitedBy`)
|
||||||
|
|
||||||
|
#### `management/server/idp/zitadel.go`
|
||||||
|
- Update `CreateUser()` to not set metadata
|
||||||
|
- Remove `UpdateUserAppMetadata()` implementation
|
||||||
|
- Simplify `GetUserDataByID()` to not expect metadata
|
||||||
|
|
||||||
|
#### `management/server/user.go`
|
||||||
|
Update user creation flow:
|
||||||
|
1. `inviteNewUser()`: Set `PendingInvite = true` when creating user
|
||||||
|
2. On first login: Set `PendingInvite = false`
|
||||||
|
|
||||||
|
### Data Flow (After)
|
||||||
|
|
||||||
|
```
|
||||||
|
Invite User:
|
||||||
|
1. Admin invites user@example.com via NetBird UI
|
||||||
|
2. NetBird calls Zitadel: CreateUser(email, name) // No metadata
|
||||||
|
3. Zitadel creates user, sends invite email
|
||||||
|
4. NetBird creates User in its DB:
|
||||||
|
- Id = Zitadel user ID
|
||||||
|
- AccountID = current account
|
||||||
|
- Role = invited role
|
||||||
|
- PendingInvite = true
|
||||||
|
|
||||||
|
First Login:
|
||||||
|
1. User clicks invite, sets password in Zitadel
|
||||||
|
2. User logs into NetBird Dashboard
|
||||||
|
3. NetBird updates User in its DB:
|
||||||
|
- PendingInvite = false
|
||||||
|
- LastLogin = now
|
||||||
|
```
|
||||||
|
|
||||||
|
### Deliverable
|
||||||
|
- NetBird is single source of truth for authorization
|
||||||
|
- Zitadel stores identity only (email, name, credentials)
|
||||||
|
- No metadata sync issues between systems
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 3: Remove Legacy IdP Managers
|
||||||
|
|
||||||
|
**Goal**: Remove all non-Zitadel IdP implementations.
|
||||||
|
|
||||||
|
### Files to Delete
|
||||||
|
- `management/server/idp/auth0.go`
|
||||||
|
- `management/server/idp/auth0_test.go`
|
||||||
|
- `management/server/idp/azure.go`
|
||||||
|
- `management/server/idp/azure_test.go`
|
||||||
|
- `management/server/idp/keycloak.go`
|
||||||
|
- `management/server/idp/keycloak_test.go`
|
||||||
|
- `management/server/idp/okta.go`
|
||||||
|
- `management/server/idp/okta_test.go`
|
||||||
|
- `management/server/idp/google_workspace.go`
|
||||||
|
- `management/server/idp/google_workspace_test.go`
|
||||||
|
- `management/server/idp/jumpcloud.go`
|
||||||
|
- `management/server/idp/jumpcloud_test.go`
|
||||||
|
- `management/server/idp/authentik.go`
|
||||||
|
- `management/server/idp/authentik_test.go`
|
||||||
|
- `management/server/idp/pocketid.go`
|
||||||
|
- `management/server/idp/pocketid_test.go`
|
||||||
|
|
||||||
|
### Files to Modify
|
||||||
|
|
||||||
|
#### `management/server/idp/idp.go`
|
||||||
|
Simplify `NewManager()` to only support Zitadel:
|
||||||
|
```go
|
||||||
|
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||||
|
switch strings.ToLower(config.ManagerType) {
|
||||||
|
case "none", "":
|
||||||
|
return nil, nil
|
||||||
|
case "zitadel":
|
||||||
|
return NewZitadelManager(...)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `management/server/idp/idp.go`
|
||||||
|
Remove unused config structs:
|
||||||
|
- `Auth0ClientConfig`
|
||||||
|
- `AzureClientConfig`
|
||||||
|
- `KeycloakClientConfig`
|
||||||
|
- etc.
|
||||||
|
|
||||||
|
### Deliverable
|
||||||
|
Only Zitadel IdP manager remains. Build passes. Tests pass.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 4: External IdP Connector API
|
||||||
|
|
||||||
|
**Goal**: API for users to add external IdPs (Okta, Google, LDAP) as Zitadel connectors.
|
||||||
|
|
||||||
|
### New Files
|
||||||
|
|
||||||
|
#### `management/server/idp/connectors.go`
|
||||||
|
Wrapper for Zitadel IdP connector API:
|
||||||
|
```go
|
||||||
|
type Connector struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Type string // "oidc", "ldap", "saml"
|
||||||
|
Issuer string // for OIDC
|
||||||
|
ClientID string // for OIDC
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectorManager interface {
|
||||||
|
AddOIDCConnector(ctx, name, issuer, clientID, clientSecret string, scopes []string) (*Connector, error)
|
||||||
|
AddLDAPConnector(ctx, name, host string, port int, baseDN, bindDN, bindPassword string) (*Connector, error)
|
||||||
|
ListConnectors(ctx) ([]Connector, error)
|
||||||
|
DeleteConnector(ctx, connectorID string) error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `management/server/http/handlers/idp_connectors_handler.go`
|
||||||
|
REST API handlers:
|
||||||
|
```
|
||||||
|
POST /api/idp/connectors - Add connector
|
||||||
|
GET /api/idp/connectors - List connectors
|
||||||
|
DELETE /api/idp/connectors/{id} - Delete connector
|
||||||
|
```
|
||||||
|
|
||||||
|
### Zitadel API Endpoints Used
|
||||||
|
- `POST /management/v1/idps/generic_oidc` - Add OIDC connector
|
||||||
|
- `POST /management/v1/idps/ldap` - Add LDAP connector
|
||||||
|
- `GET /management/v1/idps` - List all IdPs
|
||||||
|
- `DELETE /management/v1/idps/{id}` - Remove IdP
|
||||||
|
|
||||||
|
### Deliverable
|
||||||
|
Admin users can add/remove external IdP connectors via NetBird API.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Stage 5: User Role Permissions
|
||||||
|
|
||||||
|
**Goal**: Enforce admin/user permissions throughout the application.
|
||||||
|
|
||||||
|
### Files to Modify
|
||||||
|
|
||||||
|
#### `management/server/types/user.go`
|
||||||
|
Simplify roles to:
|
||||||
|
```go
|
||||||
|
const (
|
||||||
|
UserRoleAdmin UserRole = "admin"
|
||||||
|
UserRoleUser UserRole = "user"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Keep `owner` as alias for `admin` for backwards compatibility.
|
||||||
|
|
||||||
|
#### Permission checks (various files)
|
||||||
|
Update permission checks to use simplified role model:
|
||||||
|
- Admin: Full access to all resources
|
||||||
|
- User: View/manage own peers only, no user/policy management
|
||||||
|
|
||||||
|
### Permission Matrix
|
||||||
|
| Resource | Admin | User |
|
||||||
|
|----------|-------|------|
|
||||||
|
| All peers | read/write | - |
|
||||||
|
| Own peers | read/write | read/write |
|
||||||
|
| Users | read/write | - |
|
||||||
|
| Groups | read/write | read (own) |
|
||||||
|
| Policies | read/write | - |
|
||||||
|
| IdP Connectors | read/write | - |
|
||||||
|
| Account settings | read/write | - |
|
||||||
|
|
||||||
|
### Deliverable
|
||||||
|
Role-based access control enforced. User role has limited dashboard access.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture Diagram
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────────┐
|
||||||
|
│ NETBIRD DEPLOYMENT │
|
||||||
|
│ │
|
||||||
|
│ ┌───────────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ CADDY (Reverse Proxy) │ │
|
||||||
|
│ │ :80/:443 → routes to appropriate service │ │
|
||||||
|
│ └───────────────────────────────────────────────────────────┘ │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ ▼ ▼ ▼ ▼ │
|
||||||
|
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
|
||||||
|
│ │ Dashboard│ │Management│ │ Signal │ │ Zitadel │ │
|
||||||
|
│ │ :80 │ │ :80 │ │ :80 │ │ :8080 (SQLite) │ │
|
||||||
|
│ └──────────┘ └────┬─────┘ └──────────┘ └──────────────────┘ │
|
||||||
|
│ │ │ │
|
||||||
|
│ │ OIDC + Mgmt API │ │
|
||||||
|
│ └──────────────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌───────────────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Relay │ Coturn (TURN) │ │
|
||||||
|
│ └───────────────────────────────────────────────────────────┘ │
|
||||||
|
└─────────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Order
|
||||||
|
|
||||||
|
1. **Stage 1** - Infrastructure (can be tested standalone)
|
||||||
|
2. **Stage 2** - Role management (requires Stage 1)
|
||||||
|
3. **Stage 3** - Remove legacy IdPs (can be done in parallel with Stage 2)
|
||||||
|
4. **Stage 4** - Connector API (requires Stage 1)
|
||||||
|
5. **Stage 5** - Permissions (requires Stage 2)
|
||||||
|
|
||||||
|
Stages 3 and 4 can be worked on in parallel after Stage 1 is complete.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
### First-Boot Credentials
|
||||||
|
- Generated with `openssl rand -base64 32` in configure.sh
|
||||||
|
- Printed to stdout during configuration (save them!)
|
||||||
|
- Marked as "must change on first login" in Zitadel via `PASSWORDCHANGEREQUIRED=true`
|
||||||
|
|
||||||
|
### Zitadel Masterkey
|
||||||
|
- Auto-generated if not provided (32 bytes)
|
||||||
|
- Stored in environment variable (docker secret in production)
|
||||||
|
- Used for Zitadel's internal encryption
|
||||||
|
|
||||||
|
### Service Account
|
||||||
|
- Machine user created via `ZITADEL_FIRSTINSTANCE_ORG_MACHINE_*` env vars
|
||||||
|
- PAT (Personal Access Token) generated with 1-year expiration
|
||||||
|
- Used by NetBird management to call Zitadel API
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Migration Notes
|
||||||
|
|
||||||
|
Existing deployments using Auth0/Azure/Keycloak will need to:
|
||||||
|
1. Deploy Zitadel alongside existing IdP
|
||||||
|
2. Add existing IdP as Zitadel connector
|
||||||
|
3. Migrate users (or let them re-authenticate via connector)
|
||||||
|
4. Switch NetBird config to use Zitadel
|
||||||
|
5. Remove old IdP
|
||||||
|
|
||||||
|
A migration guide should be provided as documentation.
|
||||||
83
infrastructure_files/Caddyfile.tmpl
Normal file
83
infrastructure_files/Caddyfile.tmpl
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
{
|
||||||
|
servers :80,:443 {
|
||||||
|
protocols h1 h2c h2 h3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(security_headers) {
|
||||||
|
header * {
|
||||||
|
# HSTS - use 1 hour for testing, increase to 63072000 (2 years) in production
|
||||||
|
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
|
||||||
|
# Prevent MIME type sniffing
|
||||||
|
X-Content-Type-Options "nosniff"
|
||||||
|
# Clickjacking protection
|
||||||
|
X-Frame-Options "SAMEORIGIN"
|
||||||
|
# XSS protection
|
||||||
|
X-XSS-Protection "1; mode=block"
|
||||||
|
# Remove server header
|
||||||
|
-Server
|
||||||
|
# Referrer policy
|
||||||
|
Referrer-Policy strict-origin-when-cross-origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:${NETBIRD_CADDY_PORT}${CADDY_SECURE_DOMAIN} {
|
||||||
|
import security_headers
|
||||||
|
|
||||||
|
# Relay
|
||||||
|
reverse_proxy /relay* relay:${NETBIRD_RELAY_INTERNAL_PORT}
|
||||||
|
|
||||||
|
# Signal - WebSocket proxy
|
||||||
|
reverse_proxy /ws-proxy/signal* signal:80
|
||||||
|
# Signal - gRPC
|
||||||
|
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||||
|
|
||||||
|
# Management - REST API
|
||||||
|
reverse_proxy /api/* management:80
|
||||||
|
# Management - WebSocket proxy
|
||||||
|
reverse_proxy /ws-proxy/management* management:80
|
||||||
|
# Management - gRPC
|
||||||
|
reverse_proxy /management.ManagementService/* h2c://management:80
|
||||||
|
|
||||||
|
# Zitadel - Admin API
|
||||||
|
reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /admin/v1/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - Auth API
|
||||||
|
reverse_proxy /zitadel.auth.v1.AuthService/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /auth/v1/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - Management API
|
||||||
|
reverse_proxy /zitadel.management.v1.ManagementService/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /management/v1/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - System API
|
||||||
|
reverse_proxy /zitadel.system.v1.SystemService/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /system/v1/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - User API v2
|
||||||
|
reverse_proxy /zitadel.user.v2.UserService/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - Assets
|
||||||
|
reverse_proxy /assets/v1/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - UI (login, console, etc.)
|
||||||
|
reverse_proxy /ui/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - OIDC endpoints
|
||||||
|
reverse_proxy /oidc/v1/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /oauth/v2/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - SAML
|
||||||
|
reverse_proxy /saml/v2/* h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Zitadel - Other
|
||||||
|
reverse_proxy /openapi/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /debug/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /device/* h2c://zitadel:8080
|
||||||
|
reverse_proxy /device h2c://zitadel:8080
|
||||||
|
|
||||||
|
# Dashboard - catch-all for frontend
|
||||||
|
reverse_proxy /* dashboard:80
|
||||||
|
}
|
||||||
@@ -2,29 +2,20 @@
|
|||||||
|
|
||||||
# Management API
|
# Management API
|
||||||
|
|
||||||
# Management API port
|
# Management API endpoint address, used by the Dashboard (Caddy handles TLS)
|
||||||
NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073}
|
NETBIRD_MGMT_API_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
|
||||||
# Management API endpoint address, used by the Dashboard
|
|
||||||
NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
|
||||||
# Management Certificate file path. These are generated by the Dashboard container
|
|
||||||
NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
|
|
||||||
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem"
|
|
||||||
# Management Certificate key file path.
|
|
||||||
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem"
|
|
||||||
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
|
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
|
||||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
|
||||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||||
|
|
||||||
# Signal
|
# Signal
|
||||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
NETBIRD_SIGNAL_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
|
||||||
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
|
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-443}
|
||||||
|
|
||||||
# Relay
|
# Relay (internal port for Caddy reverse proxy)
|
||||||
NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN}
|
NETBIRD_RELAY_INTERNAL_PORT=${NETBIRD_RELAY_INTERNAL_PORT:-80}
|
||||||
NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080}
|
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-${NETBIRD_RELAY_PROTO:-rels}://$NETBIRD_DOMAIN:${NETBIRD_RELAY_PORT:-443}}
|
||||||
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT:-rel://$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT}
|
|
||||||
# Relay auth secret
|
# Relay auth secret
|
||||||
NETBIRD_RELAY_AUTH_SECRET=
|
NETBIRD_RELAY_AUTH_SECRET=
|
||||||
|
|
||||||
@@ -141,3 +132,57 @@ export NETBIRD_RELAY_ENDPOINT
|
|||||||
export NETBIRD_RELAY_AUTH_SECRET
|
export NETBIRD_RELAY_AUTH_SECRET
|
||||||
export NETBIRD_RELAY_TAG
|
export NETBIRD_RELAY_TAG
|
||||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||||
|
|
||||||
|
# Zitadel IdP Configuration
|
||||||
|
ZITADEL_TAG=${ZITADEL_TAG:-"v4.7.6"}
|
||||||
|
# Zitadel masterkey (32 bytes, auto-generated if not set)
|
||||||
|
ZITADEL_MASTERKEY=
|
||||||
|
# Zitadel admin credentials (auto-generated if not set)
|
||||||
|
ZITADEL_ADMIN_USERNAME=
|
||||||
|
ZITADEL_ADMIN_PASSWORD=
|
||||||
|
# Zitadel external configuration
|
||||||
|
ZITADEL_EXTERNALSECURE=${ZITADEL_EXTERNALSECURE:-true}
|
||||||
|
ZITADEL_EXTERNALPORT=${ZITADEL_EXTERNALPORT:-443}
|
||||||
|
ZITADEL_TLS_MODE=${ZITADEL_TLS_MODE:-external}
|
||||||
|
# Zitadel PAT expiration (1 year from startup)
|
||||||
|
ZITADEL_PAT_EXPIRATION=
|
||||||
|
# Zitadel management endpoint
|
||||||
|
ZITADEL_MANAGEMENT_ENDPOINT=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN/management/v1
|
||||||
|
# HTTP protocol (http or https)
|
||||||
|
NETBIRD_HTTP_PROTOCOL=${NETBIRD_HTTP_PROTOCOL:-https}
|
||||||
|
# Caddy configuration
|
||||||
|
NETBIRD_CADDY_PORT=${NETBIRD_CADDY_PORT:-80}
|
||||||
|
CADDY_SECURE_DOMAIN=
|
||||||
|
|
||||||
|
# Zitadel OIDC endpoints
|
||||||
|
NETBIRD_AUTH_AUTHORITY=${NETBIRD_HTTP_PROTOCOL:-https}://$NETBIRD_DOMAIN
|
||||||
|
NETBIRD_AUTH_TOKEN_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token
|
||||||
|
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration
|
||||||
|
NETBIRD_AUTH_JWT_CERTS=${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json
|
||||||
|
NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize
|
||||||
|
NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization
|
||||||
|
NETBIRD_AUTH_USER_ID_CLAIM=${NETBIRD_AUTH_USER_ID_CLAIM:-sub}
|
||||||
|
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-"openid profile email offline_access"}
|
||||||
|
|
||||||
|
# Zitadel exports
|
||||||
|
export ZITADEL_TAG
|
||||||
|
export ZITADEL_MASTERKEY
|
||||||
|
export ZITADEL_ADMIN_USERNAME
|
||||||
|
export ZITADEL_ADMIN_PASSWORD
|
||||||
|
export ZITADEL_EXTERNALSECURE
|
||||||
|
export ZITADEL_EXTERNALPORT
|
||||||
|
export ZITADEL_TLS_MODE
|
||||||
|
export ZITADEL_PAT_EXPIRATION
|
||||||
|
export ZITADEL_MANAGEMENT_ENDPOINT
|
||||||
|
export NETBIRD_HTTP_PROTOCOL
|
||||||
|
export NETBIRD_CADDY_PORT
|
||||||
|
export CADDY_SECURE_DOMAIN
|
||||||
|
export NETBIRD_AUTH_AUTHORITY
|
||||||
|
export NETBIRD_AUTH_TOKEN_ENDPOINT
|
||||||
|
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT
|
||||||
|
export NETBIRD_AUTH_JWT_CERTS
|
||||||
|
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||||
|
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT
|
||||||
|
export NETBIRD_AUTH_USER_ID_CLAIM
|
||||||
|
export NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||||
|
export NETBIRD_RELAY_INTERNAL_PORT
|
||||||
|
|||||||
@@ -1,261 +1,137 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
if ! which curl >/dev/null 2>&1; then
|
# Check required dependencies
|
||||||
echo "This script uses curl fetch OpenID configuration from IDP."
|
for cmd in curl jq envsubst openssl; do
|
||||||
echo "Please install curl and re-run the script https://curl.se/"
|
if ! which $cmd >/dev/null 2>&1; then
|
||||||
echo ""
|
echo "This script requires $cmd. Please install it and re-run."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
done
|
||||||
if ! which jq >/dev/null 2>&1; then
|
|
||||||
echo "This script uses jq to load OpenID configuration from IDP."
|
|
||||||
echo "Please install jq and re-run the script https://stedolan.github.io/jq/"
|
|
||||||
echo ""
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
# Source configuration
|
||||||
source setup.env
|
source setup.env
|
||||||
source base.setup.env
|
source base.setup.env
|
||||||
|
|
||||||
if ! which envsubst >/dev/null 2>&1; then
|
# Validate required variables
|
||||||
echo "envsubst is needed to run this script"
|
if [[ -z "$NETBIRD_DOMAIN" ]]; then
|
||||||
if [[ $(uname) == "Darwin" ]]; then
|
echo "NETBIRD_DOMAIN is not set, please update your setup.env file"
|
||||||
echo "you can install it with homebrew (https://brew.sh):"
|
|
||||||
echo "brew install gettext"
|
|
||||||
else
|
|
||||||
if which apt-get >/dev/null 2>&1; then
|
|
||||||
echo "you can install it by running"
|
|
||||||
echo "apt-get update && apt-get install gettext-base"
|
|
||||||
else
|
|
||||||
echo "you can install it by installing the package gettext with your package manager"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then
|
# Check database configuration if using external database
|
||||||
echo NETBIRD_DOMAIN is not set, please update your setup.env file
|
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" && -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
|
||||||
echo If you are migrating from old versions, you might need to update your variables prefixes from
|
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
|
||||||
echo WIRETRUSTEE_.. TO NETBIRD_
|
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check if PostgreSQL is set as the store engine
|
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" && -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
|
||||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then
|
echo "Error: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
|
||||||
# Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set
|
exit 1
|
||||||
if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
|
|
||||||
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
|
|
||||||
echo "Please add the following line to your setup.env file:"
|
|
||||||
echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<PG_HOST> user=<PG_USER> password=<PG_PASSWORD> dbname=<PG_DB_NAME> port=<PG_PORT>"'
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
export NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check if MySQL is set as the store engine
|
# Configure for local development vs production
|
||||||
if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" ]]; then
|
|
||||||
# Exit if 'NETBIRD_STORE_ENGINE_MYSQL_DSN' is not set
|
|
||||||
if [[ -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then
|
|
||||||
echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set."
|
|
||||||
echo "Please add the following line to your setup.env file:"
|
|
||||||
echo 'NETBIRD_STORE_ENGINE_MYSQL_DSN="<username>:<password>@tcp(127.0.0.1:3306)/<database>"'
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
export NETBIRD_STORE_ENGINE_MYSQL_DSN
|
|
||||||
fi
|
|
||||||
|
|
||||||
# local development or tests
|
|
||||||
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
|
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
|
||||||
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
|
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
|
||||||
export NETBIRD_MGMT_API_ENDPOINT=http://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
export NETBIRD_MGMT_API_ENDPOINT="http://$NETBIRD_DOMAIN"
|
||||||
unset NETBIRD_MGMT_API_CERT_FILE
|
export NETBIRD_HTTP_PROTOCOL="http"
|
||||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
export ZITADEL_EXTERNALSECURE="false"
|
||||||
fi
|
export ZITADEL_EXTERNALPORT="80"
|
||||||
|
export ZITADEL_TLS_MODE="disabled"
|
||||||
# if not provided, we generate a turn password
|
export NETBIRD_RELAY_PROTO="rel"
|
||||||
if [[ "x-$TURN_PASSWORD" == "x-" ]]; then
|
|
||||||
export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
|
|
||||||
fi
|
|
||||||
|
|
||||||
TURN_EXTERNAL_IP_CONFIG="#"
|
|
||||||
|
|
||||||
if [[ "x-$NETBIRD_TURN_EXTERNAL_IP" == "x-" ]]; then
|
|
||||||
echo "discovering server's public IP"
|
|
||||||
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip')
|
|
||||||
if [[ "x-$IP" != "x-" ]]; then
|
|
||||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
|
||||||
else
|
|
||||||
echo "unable to discover server's public IP"
|
|
||||||
fi
|
|
||||||
else
|
else
|
||||||
echo "${NETBIRD_TURN_EXTERNAL_IP}"| egrep '([0-9]{1,3}\.){3}[0-9]{1,3}$' > /dev/null
|
export NETBIRD_HTTP_PROTOCOL="https"
|
||||||
if [[ $? -eq 0 ]]; then
|
export ZITADEL_EXTERNALSECURE="true"
|
||||||
echo "using provided server's public IP"
|
export ZITADEL_EXTERNALPORT="443"
|
||||||
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
|
export ZITADEL_TLS_MODE="external"
|
||||||
else
|
export NETBIRD_RELAY_PROTO="rels"
|
||||||
echo "provided NETBIRD_TURN_EXTERNAL_IP $NETBIRD_TURN_EXTERNAL_IP is invalid, please correct it and try again"
|
export CADDY_SECURE_DOMAIN=", $NETBIRD_DOMAIN:443"
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Auto-generate secrets if not provided
|
||||||
|
[[ -z "$TURN_PASSWORD" ]] && export TURN_PASSWORD=$(openssl rand -base64 32 | sed 's/=//g')
|
||||||
|
[[ -z "$NETBIRD_RELAY_AUTH_SECRET" ]] && export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
|
||||||
|
[[ -z "$ZITADEL_MASTERKEY" ]] && export ZITADEL_MASTERKEY=$(openssl rand -base64 32 | head -c 32)
|
||||||
|
|
||||||
|
# Generate Zitadel admin credentials if not provided
|
||||||
|
if [[ -z "$ZITADEL_ADMIN_USERNAME" ]]; then
|
||||||
|
export ZITADEL_ADMIN_USERNAME="admin@${NETBIRD_DOMAIN}"
|
||||||
|
fi
|
||||||
|
if [[ -z "$ZITADEL_ADMIN_PASSWORD" ]]; then
|
||||||
|
export ZITADEL_ADMIN_PASSWORD="$(openssl rand -base64 32 | sed 's/=//g')!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set Zitadel PAT expiration (1 year from now)
|
||||||
|
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||||
|
export ZITADEL_PAT_EXPIRATION=$(date -u -v+1y "+%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
else
|
||||||
|
export ZITADEL_PAT_EXPIRATION=$(date -u -d "+1 year" "+%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Discover external IP for TURN
|
||||||
|
TURN_EXTERNAL_IP_CONFIG="#"
|
||||||
|
if [[ -z "$NETBIRD_TURN_EXTERNAL_IP" ]]; then
|
||||||
|
IP=$(curl -s -4 https://jsonip.com | jq -r '.ip' 2>/dev/null || echo "")
|
||||||
|
[[ -n "$IP" ]] && TURN_EXTERNAL_IP_CONFIG="external-ip=$IP"
|
||||||
|
elif echo "$NETBIRD_TURN_EXTERNAL_IP" | grep -qE '^([0-9]{1,3}\.){3}[0-9]{1,3}$'; then
|
||||||
|
TURN_EXTERNAL_IP_CONFIG="external-ip=$NETBIRD_TURN_EXTERNAL_IP"
|
||||||
|
fi
|
||||||
export TURN_EXTERNAL_IP_CONFIG
|
export TURN_EXTERNAL_IP_CONFIG
|
||||||
|
|
||||||
# if not provided, we generate a relay auth secret
|
# Configure endpoints
|
||||||
if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then
|
export NETBIRD_AUTH_AUTHORITY="${NETBIRD_HTTP_PROTOCOL}://${NETBIRD_DOMAIN}"
|
||||||
export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
|
export NETBIRD_AUTH_TOKEN_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/token"
|
||||||
fi
|
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/.well-known/openid-configuration"
|
||||||
|
export NETBIRD_AUTH_JWT_CERTS="${NETBIRD_AUTH_AUTHORITY}/.well-known/jwks.json"
|
||||||
artifacts_path="./artifacts"
|
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/authorize"
|
||||||
mkdir -p $artifacts_path
|
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/oauth/v2/device_authorization"
|
||||||
|
export ZITADEL_MANAGEMENT_ENDPOINT="${NETBIRD_AUTH_AUTHORITY}/management/v1"
|
||||||
|
export NETBIRD_RELAY_ENDPOINT="${NETBIRD_RELAY_PROTO}://${NETBIRD_DOMAIN}:${ZITADEL_EXTERNALPORT}"
|
||||||
|
|
||||||
|
# Volume names (with backwards compatibility)
|
||||||
MGMT_VOLUMENAME="${VOLUME_PREFIX}${MGMT_VOLUMESUFFIX}"
|
MGMT_VOLUMENAME="${VOLUME_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||||
SIGNAL_VOLUMENAME="${VOLUME_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
SIGNAL_VOLUMENAME="${VOLUME_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||||
LETSENCRYPT_VOLUMENAME="${VOLUME_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
|
|
||||||
# if volume with wiretrustee- prefix already exists, use it, else create new with netbird-
|
|
||||||
OLD_PREFIX='wiretrustee-'
|
OLD_PREFIX='wiretrustee-'
|
||||||
if docker volume ls | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"; then
|
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${MGMT_VOLUMESUFFIX}" && MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
|
||||||
MGMT_VOLUMENAME="${OLD_PREFIX}${MGMT_VOLUMESUFFIX}"
|
docker volume ls 2>/dev/null | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}" && SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
||||||
fi
|
export MGMT_VOLUMENAME SIGNAL_VOLUMENAME
|
||||||
if docker volume ls | grep -q "${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"; then
|
|
||||||
SIGNAL_VOLUMENAME="${OLD_PREFIX}${SIGNAL_VOLUMESUFFIX}"
|
# Preserve existing encryption key
|
||||||
fi
|
if test -f 'management.json'; then
|
||||||
if docker volume ls | grep -q "${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"; then
|
encKey=$(jq -r ".DataStoreEncryptionKey" management.json 2>/dev/null || echo "null")
|
||||||
LETSENCRYPT_VOLUMENAME="${OLD_PREFIX}${LETSENCRYPT_VOLUMESUFFIX}"
|
[[ "$encKey" != "null" && -n "$encKey" ]] && export NETBIRD_DATASTORE_ENC_KEY="$encKey"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
export MGMT_VOLUMENAME
|
# Create artifacts directory and backup existing files
|
||||||
export SIGNAL_VOLUMENAME
|
artifacts_path="./artifacts"
|
||||||
export LETSENCRYPT_VOLUMENAME
|
mkdir -p "$artifacts_path"
|
||||||
|
bkp_postfix="$(date +%s)"
|
||||||
#backwards compatibility after migrating to generic OIDC with Auth0
|
for file in docker-compose.yml management.json turnserver.conf Caddyfile; do
|
||||||
if [[ -z "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" ]]; then
|
[[ -f "${artifacts_path}/${file}" ]] && cp "${artifacts_path}/${file}" "${artifacts_path}/${file}.bkp.${bkp_postfix}"
|
||||||
|
|
||||||
if [[ -z "${NETBIRD_AUTH0_DOMAIN}" ]]; then
|
|
||||||
# not a backward compatible state
|
|
||||||
echo "NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT property must be set in the setup.env file"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "It seems like you provided an old setup.env file."
|
|
||||||
echo "Since the release of v0.8.10, we introduced a new set of properties."
|
|
||||||
echo "The script is backward compatible and will continue automatically."
|
|
||||||
echo "In the future versions it will be deprecated. Please refer to the documentation to learn about the changes http://netbird.io/docs/getting-started/self-hosting"
|
|
||||||
|
|
||||||
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://${NETBIRD_AUTH0_DOMAIN}/.well-known/openid-configuration"
|
|
||||||
export NETBIRD_USE_AUTH0="true"
|
|
||||||
export NETBIRD_AUTH_AUDIENCE=${NETBIRD_AUTH0_AUDIENCE}
|
|
||||||
export NETBIRD_AUTH_CLIENT_ID=${NETBIRD_AUTH0_CLIENT_ID}
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "loading OpenID configuration from ${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT} to the openid-configuration.json file"
|
|
||||||
curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o ${artifacts_path}/openid-configuration.json
|
|
||||||
|
|
||||||
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' ${artifacts_path}/openid-configuration.json)
|
|
||||||
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' ${artifacts_path}/openid-configuration.json)
|
|
||||||
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' ${artifacts_path}/openid-configuration.json)
|
|
||||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' ${artifacts_path}/openid-configuration.json)
|
|
||||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT=$(jq -r '.authorization_endpoint' ${artifacts_path}/openid-configuration.json)
|
|
||||||
|
|
||||||
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
|
||||||
# user enabled Device Authorization Grant feature
|
|
||||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ "$NETBIRD_TOKEN_SOURCE" = "idToken" ]; then
|
|
||||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN=true
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Check if letsencrypt was disabled
|
|
||||||
if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
|
||||||
export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443"
|
|
||||||
export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT"
|
|
||||||
export NETBIRD_RELAY_ENDPOINT="rels://$NETBIRD_DOMAIN:$NETBIRD_RELAY_PORT/relay"
|
|
||||||
|
|
||||||
echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore"
|
|
||||||
echo " and a reverse-proxy with Https needs to be placed in front of netbird!"
|
|
||||||
echo "The following forwards have to be setup:"
|
|
||||||
echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80"
|
|
||||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT"
|
|
||||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT"
|
|
||||||
echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80"
|
|
||||||
echo "- $NETBIRD_RELAY_ENDPOINT/ -http-> relay:33080"
|
|
||||||
echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script."
|
|
||||||
echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!"
|
|
||||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
|
||||||
unset NETBIRD_MGMT_API_CERT_FILE
|
|
||||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
|
||||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Check if management identity provider is set
|
|
||||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
|
||||||
EXTRA_CONFIG={}
|
|
||||||
|
|
||||||
# extract extra config from all env prefixed with NETBIRD_IDP_MGMT_EXTRA_
|
|
||||||
for var in ${!NETBIRD_IDP_MGMT_EXTRA_*}; do
|
|
||||||
# convert key snake case to camel case
|
|
||||||
key=$(
|
|
||||||
echo "${var#NETBIRD_IDP_MGMT_EXTRA_}" | awk -F "_" \
|
|
||||||
'{for (i=1; i<=NF; i++) {output=output substr($i,1,1) tolower(substr($i,2))} print output}'
|
|
||||||
)
|
|
||||||
value="${!var}"
|
|
||||||
|
|
||||||
echo "$var"
|
|
||||||
EXTRA_CONFIG=$(jq --arg k "$key" --arg v "$value" '.[$k] = $v' <<<"$EXTRA_CONFIG")
|
|
||||||
done
|
|
||||||
|
|
||||||
export NETBIRD_MGMT_IDP
|
|
||||||
export NETBIRD_IDP_MGMT_CLIENT_ID
|
|
||||||
export NETBIRD_IDP_MGMT_CLIENT_SECRET
|
|
||||||
export NETBIRD_IDP_MGMT_EXTRA_CONFIG=$EXTRA_CONFIG
|
|
||||||
else
|
|
||||||
export NETBIRD_IDP_MGMT_EXTRA_CONFIG={}
|
|
||||||
fi
|
|
||||||
|
|
||||||
IFS=',' read -r -a REDIRECT_URL_PORTS <<< "$NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS"
|
|
||||||
REDIRECT_URLS=""
|
|
||||||
for port in "${REDIRECT_URL_PORTS[@]}"; do
|
|
||||||
REDIRECT_URLS+="\"http://localhost:${port}\","
|
|
||||||
done
|
done
|
||||||
|
|
||||||
export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,}
|
# Generate configuration files
|
||||||
|
envsubst < docker-compose.yml.tmpl > "$artifacts_path/docker-compose.yml"
|
||||||
|
envsubst < management.json.tmpl | jq . > "$artifacts_path/management.json"
|
||||||
|
envsubst < turnserver.conf.tmpl > "$artifacts_path/turnserver.conf"
|
||||||
|
envsubst < Caddyfile.tmpl > "$artifacts_path/Caddyfile"
|
||||||
|
|
||||||
# Remove audience for providers that do not support it
|
# Print summary
|
||||||
if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
|
echo ""
|
||||||
export NETBIRD_DASH_AUTH_AUDIENCE=none
|
echo "=========================================="
|
||||||
export NETBIRD_AUTH_PKCE_AUDIENCE=
|
echo " NetBird Configuration Complete"
|
||||||
fi
|
echo "=========================================="
|
||||||
|
echo " Domain: $NETBIRD_DOMAIN"
|
||||||
# Read the encryption key
|
echo " Protocol: $NETBIRD_HTTP_PROTOCOL"
|
||||||
if test -f 'management.json'; then
|
echo " Zitadel: $ZITADEL_TAG (SQLite)"
|
||||||
encKey=$(jq -r ".DataStoreEncryptionKey" management.json)
|
echo "=========================================="
|
||||||
if [[ "$encKey" != "null" ]]; then
|
echo ""
|
||||||
export NETBIRD_DATASTORE_ENC_KEY=$encKey
|
echo " ADMIN CREDENTIALS (save these!):"
|
||||||
|
echo " Username: $ZITADEL_ADMIN_USERNAME"
|
||||||
fi
|
echo " Password: $ZITADEL_ADMIN_PASSWORD"
|
||||||
fi
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
env | grep NETBIRD
|
echo ""
|
||||||
|
echo "To start NetBird:"
|
||||||
bkp_postfix="$(date +%s)"
|
echo " cd $artifacts_path && docker compose up -d"
|
||||||
if test -f "${artifacts_path}/docker-compose.yml"; then
|
echo ""
|
||||||
cp $artifacts_path/docker-compose.yml "${artifacts_path}/docker-compose.yml.bkp.${bkp_postfix}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if test -f "${artifacts_path}/management.json"; then
|
|
||||||
cp $artifacts_path/management.json "${artifacts_path}/management.json.bkp.${bkp_postfix}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if test -f "${artifacts_path}/turnserver.conf"; then
|
|
||||||
cp ${artifacts_path}/turnserver.conf "${artifacts_path}/turnserver.conf.bkp.${bkp_postfix}"
|
|
||||||
fi
|
|
||||||
envsubst <docker-compose.yml.tmpl >$artifacts_path/docker-compose.yml
|
|
||||||
envsubst <management.json.tmpl | jq . >$artifacts_path/management.json
|
|
||||||
envsubst <turnserver.conf.tmpl >$artifacts_path/turnserver.conf
|
|
||||||
|
|||||||
@@ -7,108 +7,137 @@ x-default: &default
|
|||||||
max-file: '2'
|
max-file: '2'
|
||||||
|
|
||||||
services:
|
services:
|
||||||
|
# Caddy reverse proxy
|
||||||
|
caddy:
|
||||||
|
<<: *default
|
||||||
|
image: caddy:2
|
||||||
|
networks: [netbird]
|
||||||
|
ports:
|
||||||
|
- '443:443'
|
||||||
|
- '443:443/udp'
|
||||||
|
- '80:80'
|
||||||
|
volumes:
|
||||||
|
- netbird-caddy-data:/data
|
||||||
|
- ./Caddyfile:/etc/caddy/Caddyfile:ro
|
||||||
|
depends_on:
|
||||||
|
zitadel:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
# UI dashboard
|
# UI dashboard
|
||||||
dashboard:
|
dashboard:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG
|
||||||
ports:
|
networks: [netbird]
|
||||||
- 80:80
|
|
||||||
- 443:443
|
|
||||||
environment:
|
environment:
|
||||||
# Endpoints
|
|
||||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||||
# OIDC
|
- AUTH_AUDIENCE=$NETBIRD_AUTH_CLIENT_ID
|
||||||
- AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE
|
|
||||||
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||||
- AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET
|
- AUTH_CLIENT_SECRET=
|
||||||
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
|
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
|
||||||
- USE_AUTH0=$NETBIRD_USE_AUTH0
|
- USE_AUTH0=false
|
||||||
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
|
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||||
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
|
- AUTH_REDIRECT_URI=/nb-auth
|
||||||
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
|
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
|
||||||
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
|
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
|
||||||
# SSL
|
|
||||||
- NGINX_SSL_PORT=443
|
- NGINX_SSL_PORT=443
|
||||||
# Letsencrypt
|
- LETSENCRYPT_DOMAIN=none
|
||||||
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
|
depends_on:
|
||||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
zitadel:
|
||||||
volumes:
|
condition: service_healthy
|
||||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
|
||||||
|
|
||||||
# Signal
|
# Signal
|
||||||
signal:
|
signal:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||||
depends_on:
|
networks: [netbird]
|
||||||
- dashboard
|
|
||||||
volumes:
|
volumes:
|
||||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
command: ["--log-file", "console"]
|
||||||
ports:
|
|
||||||
- $NETBIRD_SIGNAL_PORT:80
|
|
||||||
# # port and command for Let's Encrypt validation
|
|
||||||
# - 443:443
|
|
||||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
|
||||||
command: [
|
|
||||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
|
||||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
|
||||||
"--log-file", "console"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Relay
|
# Relay
|
||||||
relay:
|
relay:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: netbirdio/relay:$NETBIRD_RELAY_TAG
|
image: netbirdio/relay:$NETBIRD_RELAY_TAG
|
||||||
|
networks: [netbird]
|
||||||
environment:
|
environment:
|
||||||
- NB_LOG_LEVEL=info
|
- NB_LOG_LEVEL=info
|
||||||
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
|
- NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_INTERNAL_PORT
|
||||||
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
- NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_ENDPOINT
|
||||||
# todo: change to a secure secret
|
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
||||||
- NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
|
|
||||||
ports:
|
|
||||||
- $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
|
|
||||||
|
|
||||||
# Management
|
# Management
|
||||||
management:
|
management:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
|
image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG
|
||||||
depends_on:
|
networks: [netbird]
|
||||||
- dashboard
|
|
||||||
volumes:
|
volumes:
|
||||||
- $MGMT_VOLUMENAME:/var/lib/netbird
|
- $MGMT_VOLUMENAME:/var/lib/netbird
|
||||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
|
||||||
- ./management.json:/etc/netbird/management.json
|
- ./management.json:/etc/netbird/management.json
|
||||||
ports:
|
|
||||||
- $NETBIRD_MGMT_API_PORT:443 #API port
|
|
||||||
# # command for Let's Encrypt validation without dashboard container
|
|
||||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
|
||||||
command: [
|
command: [
|
||||||
"--port", "443",
|
"--port", "80",
|
||||||
"--log-file", "console",
|
"--log-file", "console",
|
||||||
"--log-level", "info",
|
"--log-level", "info",
|
||||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN",
|
||||||
]
|
"--idp-sign-key-refresh-enabled"
|
||||||
|
]
|
||||||
environment:
|
environment:
|
||||||
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
|
||||||
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
|
- NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN
|
||||||
|
depends_on:
|
||||||
|
zitadel:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
# Coturn
|
# Coturn
|
||||||
coturn:
|
coturn:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: coturn/coturn:$COTURN_TAG
|
image: coturn/coturn:$COTURN_TAG
|
||||||
#domainname: $TURN_DOMAIN # only needed when TLS is enabled
|
|
||||||
volumes:
|
volumes:
|
||||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||||
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
|
|
||||||
# - ./cert.pem:/etc/coturn/certs/cert.pem:ro
|
|
||||||
network_mode: host
|
network_mode: host
|
||||||
command:
|
command:
|
||||||
- -c /etc/turnserver.conf
|
- -c /etc/turnserver.conf
|
||||||
|
|
||||||
|
# Zitadel Identity Provider (single container with SQLite)
|
||||||
|
zitadel:
|
||||||
|
<<: *default
|
||||||
|
image: ghcr.io/zitadel/zitadel:$ZITADEL_TAG
|
||||||
|
networks: [netbird]
|
||||||
|
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||||
|
environment:
|
||||||
|
- ZITADEL_LOG_LEVEL=info
|
||||||
|
- ZITADEL_MASTERKEY=$ZITADEL_MASTERKEY
|
||||||
|
- ZITADEL_EXTERNALSECURE=$ZITADEL_EXTERNALSECURE
|
||||||
|
- ZITADEL_TLS_ENABLED=false
|
||||||
|
- ZITADEL_EXTERNALPORT=$ZITADEL_EXTERNALPORT
|
||||||
|
- ZITADEL_EXTERNALDOMAIN=$NETBIRD_DOMAIN
|
||||||
|
# SQLite database (no separate PostgreSQL container needed)
|
||||||
|
- ZITADEL_DATABASE_SQLITE_PATH=/data/zitadel.db
|
||||||
|
# First instance: Admin user
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_USERNAME=$ZITADEL_ADMIN_USERNAME
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORD=$ZITADEL_ADMIN_PASSWORD
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_HUMAN_PASSWORDCHANGEREQUIRED=true
|
||||||
|
# First instance: Service account for NetBird management
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_USERNAME=netbird-service
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINE_NAME=NetBird Service Account
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_MACHINEKEY_TYPE=1
|
||||||
|
- ZITADEL_FIRSTINSTANCE_ORG_MACHINE_PAT_EXPIRATIONDATE=$ZITADEL_PAT_EXPIRATION
|
||||||
|
volumes:
|
||||||
|
- netbird-zitadel-data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "/app/zitadel", "ready"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
start_period: 30s
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
$MGMT_VOLUMENAME:
|
$MGMT_VOLUMENAME:
|
||||||
$SIGNAL_VOLUMENAME:
|
$SIGNAL_VOLUMENAME:
|
||||||
$LETSENCRYPT_VOLUMENAME:
|
netbird-caddy-data:
|
||||||
|
netbird-zitadel-data:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
netbird:
|
||||||
|
|||||||
@@ -45,18 +45,18 @@
|
|||||||
"Engine": "$NETBIRD_STORE_CONFIG_ENGINE"
|
"Engine": "$NETBIRD_STORE_CONFIG_ENGINE"
|
||||||
},
|
},
|
||||||
"HttpConfig": {
|
"HttpConfig": {
|
||||||
"Address": "0.0.0.0:$NETBIRD_MGMT_API_PORT",
|
"Address": "0.0.0.0:80",
|
||||||
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
|
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||||
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
|
"AuthAudience": "$NETBIRD_AUTH_CLIENT_ID",
|
||||||
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
|
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
|
||||||
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
|
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
|
||||||
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
|
"CertFile": "",
|
||||||
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
"CertKey": "",
|
||||||
"IdpSignKeyRefreshEnabled": $NETBIRD_MGMT_IDP_SIGNKEY_REFRESH,
|
"IdpSignKeyRefreshEnabled": true,
|
||||||
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
"OIDCConfigEndpoint": "$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
||||||
},
|
},
|
||||||
"IdpManagerConfig": {
|
"IdpManagerConfig": {
|
||||||
"ManagerType": "$NETBIRD_MGMT_IDP",
|
"ManagerType": "zitadel",
|
||||||
"ClientConfig": {
|
"ClientConfig": {
|
||||||
"Issuer": "$NETBIRD_AUTH_AUTHORITY",
|
"Issuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||||
@@ -64,40 +64,28 @@
|
|||||||
"ClientSecret": "$NETBIRD_IDP_MGMT_CLIENT_SECRET",
|
"ClientSecret": "$NETBIRD_IDP_MGMT_CLIENT_SECRET",
|
||||||
"GrantType": "client_credentials"
|
"GrantType": "client_credentials"
|
||||||
},
|
},
|
||||||
"ExtraConfig": $NETBIRD_IDP_MGMT_EXTRA_CONFIG,
|
"ExtraConfig": {
|
||||||
"Auth0ClientCredentials": null,
|
"ManagementEndpoint": "$ZITADEL_MANAGEMENT_ENDPOINT"
|
||||||
"AzureClientCredentials": null,
|
}
|
||||||
"KeycloakClientCredentials": null,
|
},
|
||||||
"ZitadelClientCredentials": null
|
|
||||||
},
|
|
||||||
"DeviceAuthorizationFlow": {
|
"DeviceAuthorizationFlow": {
|
||||||
"Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER",
|
"Provider": "hosted",
|
||||||
"ProviderConfig": {
|
"ProviderConfig": {
|
||||||
"Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE",
|
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||||
"AuthorizationEndpoint": "",
|
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||||
"Domain": "$NETBIRD_AUTH0_DOMAIN",
|
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||||
"ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID",
|
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
||||||
"ClientSecret": "",
|
"Scope": "openid"
|
||||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
}
|
||||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
|
||||||
"Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE",
|
|
||||||
"UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN,
|
|
||||||
"RedirectURLs": null
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"PKCEAuthorizationFlow": {
|
"PKCEAuthorizationFlow": {
|
||||||
"ProviderConfig": {
|
"ProviderConfig": {
|
||||||
"Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE",
|
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||||
"ClientID": "$NETBIRD_AUTH_CLIENT_ID",
|
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
|
||||||
"ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET",
|
|
||||||
"Domain": "",
|
|
||||||
"AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT",
|
"AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT",
|
||||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||||
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
|
"Scope": "openid profile email offline_access",
|
||||||
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
|
"RedirectURLs": ["http://localhost:53000/", "http://localhost:54000/"]
|
||||||
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
|
|
||||||
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
|
|
||||||
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,117 +1,65 @@
|
|||||||
## example file, you can copy this file to setup.env and update its values
|
# NetBird Self-Hosted Setup Configuration
|
||||||
##
|
# Copy this file to setup.env and configure the required values
|
||||||
|
|
||||||
# Image tags
|
# -------------------------------------------
|
||||||
# you can force specific tags for each component; will be set to latest if empty
|
# Required: Domain Configuration
|
||||||
|
# -------------------------------------------
|
||||||
|
# Your NetBird domain (e.g., netbird.mydomain.com)
|
||||||
|
NETBIRD_DOMAIN=""
|
||||||
|
|
||||||
|
# -------------------------------------------
|
||||||
|
# Optional: Image Tags
|
||||||
|
# -------------------------------------------
|
||||||
|
# Leave empty to use 'latest' for all components
|
||||||
NETBIRD_DASHBOARD_TAG=""
|
NETBIRD_DASHBOARD_TAG=""
|
||||||
NETBIRD_SIGNAL_TAG=""
|
NETBIRD_SIGNAL_TAG=""
|
||||||
NETBIRD_MANAGEMENT_TAG=""
|
NETBIRD_MANAGEMENT_TAG=""
|
||||||
COTURN_TAG=""
|
COTURN_TAG=""
|
||||||
NETBIRD_RELAY_TAG=""
|
NETBIRD_RELAY_TAG=""
|
||||||
|
|
||||||
# Dashboard domain. e.g. app.mydomain.com
|
# Zitadel version (default: v2.64.1)
|
||||||
NETBIRD_DOMAIN=""
|
ZITADEL_TAG=""
|
||||||
|
|
||||||
# TURN server domain. e.g. turn.mydomain.com
|
# -------------------------------------------
|
||||||
# if not specified it will assume NETBIRD_DOMAIN
|
# Optional: TURN Server Configuration
|
||||||
|
# -------------------------------------------
|
||||||
|
# TURN server domain (defaults to NETBIRD_DOMAIN)
|
||||||
NETBIRD_TURN_DOMAIN=""
|
NETBIRD_TURN_DOMAIN=""
|
||||||
|
|
||||||
# TURN server public IP address
|
# TURN server public IP address
|
||||||
# required for a connection involving peers in
|
# Required for peers behind NAT to connect
|
||||||
# the same network as the server and external peers
|
|
||||||
# usually matches the IP for the domain set in NETBIRD_TURN_DOMAIN
|
|
||||||
NETBIRD_TURN_EXTERNAL_IP=""
|
NETBIRD_TURN_EXTERNAL_IP=""
|
||||||
|
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
# OIDC
|
# Optional: Database Configuration
|
||||||
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration
|
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
# Store engine: sqlite (default), postgres, or mysql
|
||||||
# The default setting is to transmit the audience to the IDP during authorization. However,
|
NETBIRD_STORE_CONFIG_ENGINE=""
|
||||||
# if your IDP does not have this capability, you can turn this off by setting it to false.
|
|
||||||
#NETBIRD_DASH_AUTH_USE_AUDIENCE=false
|
# For PostgreSQL:
|
||||||
NETBIRD_AUTH_AUDIENCE=""
|
# NETBIRD_STORE_ENGINE_POSTGRES_DSN="host=<HOST> user=<USER> password=<PASS> dbname=<DB> port=5432"
|
||||||
# e.g. netbird-client
|
|
||||||
NETBIRD_AUTH_CLIENT_ID=""
|
# For MySQL:
|
||||||
# indicates the scopes that will be requested to the IDP
|
# NETBIRD_STORE_ENGINE_MYSQL_DSN="<user>:<pass>@tcp(127.0.0.1:3306)/<db>"
|
||||||
NETBIRD_AUTH_SUPPORTED_SCOPES=""
|
|
||||||
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
|
|
||||||
# NETBIRD_AUTH_CLIENT_SECRET=""
|
|
||||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
|
||||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
|
||||||
# indicates whether to use Auth0 or not: true or false
|
|
||||||
NETBIRD_USE_AUTH0="false"
|
|
||||||
# if your IDP provider doesn't support fragmented URIs, configure custom
|
|
||||||
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
|
|
||||||
# NETBIRD_AUTH_REDIRECT_URI="/peers"
|
|
||||||
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
|
|
||||||
# Updates the preference to use id tokens instead of access token on dashboard
|
|
||||||
# Okta and Gitlab IDPs can benefit from this
|
|
||||||
# NETBIRD_TOKEN_SOURCE="idToken"
|
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
# OIDC Device Authorization Flow
|
# Optional: Extra Settings
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
|
# Disable anonymous metrics (default: false)
|
||||||
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
|
|
||||||
# Some IDPs requires different audience, scopes and to use id token for device authorization flow
|
|
||||||
# you can customize here:
|
|
||||||
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
|
||||||
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid"
|
|
||||||
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=false
|
|
||||||
# -------------------------------------------
|
|
||||||
# OIDC PKCE Authorization Flow
|
|
||||||
# -------------------------------------------
|
|
||||||
# Comma separated port numbers. if already in use, PKCE flow will choose an available port from the list as an alternative
|
|
||||||
# eg. 53000,54000
|
|
||||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS="53000"
|
|
||||||
# -------------------------------------------
|
|
||||||
# IDP Management
|
|
||||||
# -------------------------------------------
|
|
||||||
# eg. zitadel, auth0, azure, keycloak
|
|
||||||
NETBIRD_MGMT_IDP="none"
|
|
||||||
# Some IDPs requires different client id and client secret for management api
|
|
||||||
NETBIRD_IDP_MGMT_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
|
||||||
NETBIRD_IDP_MGMT_CLIENT_SECRET=""
|
|
||||||
# Required when setting up with Keycloak "https://<YOUR_KEYCLOAK_HOST_AND_PORT>/admin/realms/netbird"
|
|
||||||
# NETBIRD_IDP_MGMT_EXTRA_ADMIN_ENDPOINT=
|
|
||||||
# With some IDPs may be needed enabling automatic refresh of signing keys on expire
|
|
||||||
# NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=false
|
|
||||||
# NETBIRD_IDP_MGMT_EXTRA_ variables. See https://docs.netbird.io/selfhosted/identity-providers for more information about your IDP of choice.
|
|
||||||
# -------------------------------------------
|
|
||||||
# Letsencrypt
|
|
||||||
# -------------------------------------------
|
|
||||||
# Disable letsencrypt
|
|
||||||
# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead
|
|
||||||
NETBIRD_DISABLE_LETSENCRYPT=false
|
|
||||||
# e.g. hello@mydomain.com
|
|
||||||
NETBIRD_LETSENCRYPT_EMAIL=""
|
|
||||||
# -------------------------------------------
|
|
||||||
# Extra settings
|
|
||||||
# -------------------------------------------
|
|
||||||
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
|
|
||||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
|
||||||
|
# DNS domain for peer resolution (default: netbird.selfhosted)
|
||||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||||
# Disable default all-to-all policy for new accounts
|
|
||||||
|
# Disable default all-to-all policy for new accounts (default: false)
|
||||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||||
|
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
# Relay settings
|
# Advanced: Zitadel Client IDs
|
||||||
# -------------------------------------------
|
# -------------------------------------------
|
||||||
# Relay server domain. e.g. relay.mydomain.com
|
# These are auto-generated by Zitadel on first boot
|
||||||
# if not specified it will assume NETBIRD_DOMAIN
|
# Only set these if migrating from an existing Zitadel setup
|
||||||
NETBIRD_RELAY_DOMAIN=""
|
# NETBIRD_AUTH_CLIENT_ID=""
|
||||||
|
# NETBIRD_AUTH_CLIENT_ID_CLI=""
|
||||||
# Relay server connection port. If none is supplied
|
# NETBIRD_IDP_MGMT_CLIENT_ID=""
|
||||||
# it will default to 33080
|
# NETBIRD_IDP_MGMT_CLIENT_SECRET=""
|
||||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
|
||||||
NETBIRD_RELAY_PORT=""
|
|
||||||
|
|
||||||
# Management API connecting port. If none is supplied
|
|
||||||
# it will default to 33073
|
|
||||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
|
||||||
NETBIRD_MGMT_API_PORT=""
|
|
||||||
|
|
||||||
# Signal service connecting port. If none is supplied
|
|
||||||
# it will default to 10000
|
|
||||||
# should be updated to match TLS-port of reverse proxy when netbird is running behind reverse proxy
|
|
||||||
NETBIRD_SIGNAL_PORT=""
|
|
||||||
|
|||||||
@@ -587,42 +587,40 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cache
|
|||||||
time.Sleep(delay)
|
time.Sleep(delay)
|
||||||
}
|
}
|
||||||
|
|
||||||
userData, err := am.idpManager.GetAllAccounts(ctx)
|
// Get all users from IdP
|
||||||
|
idpUsers, err := am.idpManager.GetAllUsers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData))
|
log.WithContext(ctx).Infof("%d users received from IdP management", len(idpUsers))
|
||||||
|
|
||||||
// If the Identity Provider does not support writing AppMetadata,
|
// Create a map for quick lookup of IdP users by ID
|
||||||
// in cases like this, we expect it to return all users in an "unset" field.
|
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
|
||||||
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
|
for _, user := range idpUsers {
|
||||||
// update their AppMetadata with the AccountID.
|
idpUserMap[user.ID] = user
|
||||||
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
|
}
|
||||||
for _, user := range unsetData {
|
|
||||||
accountID, err := am.Store.GetAccountByUser(ctx, user.ID)
|
// Group IdP users by their account ID from NetBird's database
|
||||||
if err == nil {
|
// NetBird DB is the source of truth for account membership
|
||||||
data := userData[accountID.Id]
|
accountUsers := make(map[string][]*idp.UserData)
|
||||||
if data == nil {
|
for _, idpUser := range idpUsers {
|
||||||
data = make([]*idp.UserData, 0, 1)
|
account, err := am.Store.GetAccountByUser(ctx, idpUser.ID)
|
||||||
}
|
if err != nil {
|
||||||
|
// User exists in IdP but not in NetBird - skip
|
||||||
user.AppMetadata.WTAccountID = accountID.Id
|
continue
|
||||||
|
}
|
||||||
userData[accountID.Id] = append(data, user)
|
accountUsers[account.Id] = append(accountUsers[account.Id], idpUser)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete(userData, idp.UnsetAccountID)
|
|
||||||
|
|
||||||
rcvdUsers := 0
|
rcvdUsers := 0
|
||||||
for accountID, users := range userData {
|
for accountID, users := range accountUsers {
|
||||||
rcvdUsers += len(users)
|
rcvdUsers += len(users)
|
||||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheEntryExpiration())
|
err = am.cacheManager.Set(am.ctx, accountID, users, cacheEntryExpiration())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
|
log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(accountUsers))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -742,10 +740,6 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return account.Id, nil
|
return account.Id, nil
|
||||||
}
|
}
|
||||||
return "", err
|
return "", err
|
||||||
@@ -757,35 +751,6 @@ func isNil(i idp.Manager) bool {
|
|||||||
return i == nil || reflect.ValueOf(i).IsNil()
|
return i == nil || reflect.ValueOf(i).IsNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
|
||||||
if !isNil(am.idpManager) {
|
|
||||||
// user can be nil if it wasn't found (e.g., just created)
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
|
||||||
// it was already set, so we skip the unnecessary update
|
|
||||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
|
||||||
accountID, userID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
|
||||||
}
|
|
||||||
// refresh cache to reflect the update
|
|
||||||
_, err = am.refreshCache(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) (any, []cacheStore.Option, error) {
|
func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) (any, []cacheStore.Option, error) {
|
||||||
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
|
||||||
accountIDString := fmt.Sprintf("%v", accountID)
|
accountIDString := fmt.Sprintf("%v", accountID)
|
||||||
@@ -797,28 +762,32 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
|
|||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||||
|
|
||||||
|
// Get users belonging to this account from NetBird's database
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userData, err := am.idpManager.GetAccount(ctx, accountIDString)
|
// Get all users from IdP (we don't have per-account queries anymore)
|
||||||
|
idpUsers, err := am.idpManager.GetAllUsers(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("%d entries received from IdP management for account %s", len(userData), accountIDString)
|
log.WithContext(ctx).Debugf("%d total users in IdP, matching against account %s", len(idpUsers), accountIDString)
|
||||||
|
|
||||||
dataMap := make(map[string]*idp.UserData, len(userData))
|
// Create a map for quick lookup of IdP users by ID
|
||||||
for _, datum := range userData {
|
idpUserMap := make(map[string]*idp.UserData, len(idpUsers))
|
||||||
dataMap[datum.ID] = datum
|
for _, user := range idpUsers {
|
||||||
|
idpUserMap[user.ID] = user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Match account users against IdP users
|
||||||
matchedUserData := make([]*idp.UserData, 0)
|
matchedUserData := make([]*idp.UserData, 0)
|
||||||
for _, user := range accountUsers {
|
for _, user := range accountUsers {
|
||||||
if user.IsServiceUser {
|
if user.IsServiceUser {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
datum, ok := dataMap[user.Id]
|
datum, ok := idpUserMap[user.Id]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id)
|
log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id)
|
||||||
continue
|
continue
|
||||||
@@ -972,7 +941,7 @@ func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers m
|
|||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
|
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count
|
||||||
func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
|
func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
|
||||||
userDataMap := make(map[string]*idp.UserData, len(data))
|
userDataMap := make(map[string]*idp.UserData, len(data))
|
||||||
for _, datum := range data {
|
for _, datum := range data {
|
||||||
@@ -980,15 +949,10 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
|
|||||||
}
|
}
|
||||||
|
|
||||||
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
// the accountUsers ID list of non integration users from store, we check if cache has all of them
|
||||||
// as result of for loop knownUsersCount will have number of users are not presented in the cashed
|
// as result of for loop knownUsersCount will have number of users are not presented in the cached
|
||||||
knownUsersCount := len(accountUsers)
|
knownUsersCount := len(accountUsers)
|
||||||
for user, loggedInOnce := range accountUsers {
|
for user := range accountUsers {
|
||||||
if datum, ok := userDataMap[user]; ok {
|
if _, ok := userDataMap[user]; ok {
|
||||||
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
|
|
||||||
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
|
|
||||||
log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
knownUsersCount--
|
knownUsersCount--
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1078,12 +1042,6 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// we should register the account ID to this user's metadata in our IDP manager
|
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1110,11 +1068,6 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
|
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||||
|
|
||||||
return newAccount.Id, nil
|
return newAccount.Id, nil
|
||||||
@@ -1139,11 +1092,6 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if newUser.PendingApproval {
|
if newUser.PendingApproval {
|
||||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
|
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
|
||||||
} else {
|
} else {
|
||||||
@@ -1155,34 +1103,30 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
|||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||||
// only possible with the enabled IdP manager
|
// Get user from NetBird's database (source of truth for authorization data)
|
||||||
if am.idpManager == nil {
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
|
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
// Check if user has a pending invite that needs to be redeemed
|
||||||
|
if user.PendingInvite {
|
||||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
|
||||||
// Our job is to just reload cache.
|
// Update user to mark invite as redeemed
|
||||||
go func() {
|
user.PendingInvite = false
|
||||||
_, err = am.refreshCache(ctx, accountID)
|
err = am.Store.SaveUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
log.WithContext(ctx).Warnf("failed to redeem invite for user %s: %v", userID, err)
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
|
||||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", userID, accountID)
|
||||||
}()
|
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/connectors"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
|
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
|
||||||
@@ -134,6 +135,7 @@ func NewAPIHandler(
|
|||||||
dns.AddEndpoints(accountManager, router)
|
dns.AddEndpoints(accountManager, router)
|
||||||
events.AddEndpoints(accountManager, router)
|
events.AddEndpoints(accountManager, router)
|
||||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||||
|
connectors.AddEndpoints(accountManager, router)
|
||||||
|
|
||||||
return rootRouter, nil
|
return rootRouter, nil
|
||||||
}
|
}
|
||||||
|
|||||||
590
management/server/http/handlers/connectors/connectors_handler.go
Normal file
590
management/server/http/handlers/connectors/connectors_handler.go
Normal file
@@ -0,0 +1,590 @@
|
|||||||
|
package connectors
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// API request/response types
|
||||||
|
|
||||||
|
// ConnectorResponse represents an IdP connector in API responses
|
||||||
|
type ConnectorResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Issuer string `json:"issuer,omitempty"`
|
||||||
|
Servers []string `json:"servers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCConnectorRequest represents a request to create an OIDC connector
|
||||||
|
type OIDCConnectorRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
|
||||||
|
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
|
||||||
|
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
|
||||||
|
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LDAPConnectorRequest represents a request to create an LDAP connector
|
||||||
|
type LDAPConnectorRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Servers []string `json:"servers"`
|
||||||
|
StartTLS bool `json:"start_tls,omitempty"`
|
||||||
|
BaseDN string `json:"base_dn"`
|
||||||
|
BindDN string `json:"bind_dn"`
|
||||||
|
BindPassword string `json:"bind_password"`
|
||||||
|
UserBase string `json:"user_base,omitempty"`
|
||||||
|
UserObjectClass []string `json:"user_object_class,omitempty"`
|
||||||
|
UserFilters []string `json:"user_filters,omitempty"`
|
||||||
|
Timeout string `json:"timeout,omitempty"`
|
||||||
|
Attributes *LDAPAttributesRequest `json:"attributes,omitempty"`
|
||||||
|
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
|
||||||
|
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
|
||||||
|
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
|
||||||
|
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LDAPAttributesRequest maps LDAP attributes to user fields
|
||||||
|
type LDAPAttributesRequest struct {
|
||||||
|
IDAttribute string `json:"id_attribute,omitempty"`
|
||||||
|
FirstNameAttribute string `json:"first_name_attribute,omitempty"`
|
||||||
|
LastNameAttribute string `json:"last_name_attribute,omitempty"`
|
||||||
|
DisplayNameAttribute string `json:"display_name_attribute,omitempty"`
|
||||||
|
EmailAttribute string `json:"email_attribute,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAMLConnectorRequest represents a request to create a SAML connector
|
||||||
|
type SAMLConnectorRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
MetadataXML string `json:"metadata_xml,omitempty"`
|
||||||
|
MetadataURL string `json:"metadata_url,omitempty"`
|
||||||
|
Binding string `json:"binding,omitempty"`
|
||||||
|
WithSignedRequest bool `json:"with_signed_request,omitempty"`
|
||||||
|
NameIDFormat string `json:"name_id_format,omitempty"`
|
||||||
|
IsAutoCreation bool `json:"is_auto_creation,omitempty"`
|
||||||
|
IsAutoUpdate bool `json:"is_auto_update,omitempty"`
|
||||||
|
IsCreationAllowed bool `json:"is_creation_allowed,omitempty"`
|
||||||
|
IsLinkingAllowed bool `json:"is_linking_allowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// handler handles HTTP requests for IdP connectors
|
||||||
|
type handler struct {
|
||||||
|
accountManager account.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddEndpoints registers the connector endpoints to the router
|
||||||
|
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||||
|
h := &handler{accountManager: accountManager}
|
||||||
|
|
||||||
|
router.HandleFunc("/connectors", h.listConnectors).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/{connectorId}", h.getConnector).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/{connectorId}", h.deleteConnector).Methods("DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/oidc", h.addOIDCConnector).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/ldap", h.addLDAPConnector).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/saml", h.addSAMLConnector).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/{connectorId}/activate", h.activateConnector).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/connectors/{connectorId}/deactivate", h.deleteConnector).Methods("POST", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnectorManager retrieves the connector manager from the IdP manager
|
||||||
|
func (h *handler) getConnectorManager() (idp.ConnectorManager, error) {
|
||||||
|
idpManager := h.accountManager.GetIdpManager()
|
||||||
|
if idpManager == nil {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "IdP manager is not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, ok := idpManager.(idp.ConnectorManager)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "IdP manager does not support connector management")
|
||||||
|
}
|
||||||
|
|
||||||
|
return connectorManager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// listConnectors returns all configured IdP connectors
|
||||||
|
func (h *handler) listConnectors(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only admins can manage connectors
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectors, err := connectorManager.ListConnectors(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to list connectors: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]*ConnectorResponse, 0, len(connectors))
|
||||||
|
for _, c := range connectors {
|
||||||
|
response = append(response, toConnectorResponse(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConnector returns a specific connector by ID
|
||||||
|
func (h *handler) getConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
connectorID := vars["connectorId"]
|
||||||
|
if connectorID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := connectorManager.GetConnector(r.Context(), connectorID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "connector not found: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteConnector removes a connector
|
||||||
|
func (h *handler) deleteConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodDelete {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
connectorID := vars["connectorId"]
|
||||||
|
if connectorID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := connectorManager.DeleteConnector(r.Context(), connectorID); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to delete connector: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addOIDCConnector creates a new OIDC connector
|
||||||
|
func (h *handler) addOIDCConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req OIDCConnectorRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Issuer == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "issuer is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ClientID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "client_id is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ClientSecret == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "client_secret is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config := idp.OIDCConnectorConfig{
|
||||||
|
Name: req.Name,
|
||||||
|
Issuer: req.Issuer,
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
ClientSecret: req.ClientSecret,
|
||||||
|
Scopes: req.Scopes,
|
||||||
|
IsAutoCreation: req.IsAutoCreation,
|
||||||
|
IsAutoUpdate: req.IsAutoUpdate,
|
||||||
|
IsCreationAllowed: req.IsCreationAllowed,
|
||||||
|
IsLinkingAllowed: req.IsLinkingAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := connectorManager.AddOIDCConnector(r.Context(), config)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add OIDC connector: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||||
|
}
|
||||||
|
|
||||||
|
// addLDAPConnector creates a new LDAP connector
|
||||||
|
func (h *handler) addLDAPConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req LDAPConnectorRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.Servers) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "at least one server is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.BaseDN == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "base_dn is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.BindDN == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "bind_dn is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.BindPassword == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "bind_password is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config := idp.LDAPConnectorConfig{
|
||||||
|
Name: req.Name,
|
||||||
|
Servers: req.Servers,
|
||||||
|
StartTLS: req.StartTLS,
|
||||||
|
BaseDN: req.BaseDN,
|
||||||
|
BindDN: req.BindDN,
|
||||||
|
BindPassword: req.BindPassword,
|
||||||
|
UserBase: req.UserBase,
|
||||||
|
UserObjectClass: req.UserObjectClass,
|
||||||
|
UserFilters: req.UserFilters,
|
||||||
|
Timeout: req.Timeout,
|
||||||
|
IsAutoCreation: req.IsAutoCreation,
|
||||||
|
IsAutoUpdate: req.IsAutoUpdate,
|
||||||
|
IsCreationAllowed: req.IsCreationAllowed,
|
||||||
|
IsLinkingAllowed: req.IsLinkingAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Attributes != nil {
|
||||||
|
config.Attributes = idp.LDAPAttributes{
|
||||||
|
IDAttribute: req.Attributes.IDAttribute,
|
||||||
|
FirstNameAttribute: req.Attributes.FirstNameAttribute,
|
||||||
|
LastNameAttribute: req.Attributes.LastNameAttribute,
|
||||||
|
DisplayNameAttribute: req.Attributes.DisplayNameAttribute,
|
||||||
|
EmailAttribute: req.Attributes.EmailAttribute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := connectorManager.AddLDAPConnector(r.Context(), config)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add LDAP connector: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||||
|
}
|
||||||
|
|
||||||
|
// addSAMLConnector creates a new SAML connector
|
||||||
|
func (h *handler) addSAMLConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req SAMLConnectorRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "name is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.MetadataXML == "" && req.MetadataURL == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "either metadata_xml or metadata_url is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config := idp.SAMLConnectorConfig{
|
||||||
|
Name: req.Name,
|
||||||
|
MetadataXML: req.MetadataXML,
|
||||||
|
MetadataURL: req.MetadataURL,
|
||||||
|
Binding: req.Binding,
|
||||||
|
WithSignedRequest: req.WithSignedRequest,
|
||||||
|
NameIDFormat: req.NameIDFormat,
|
||||||
|
IsAutoCreation: req.IsAutoCreation,
|
||||||
|
IsAutoUpdate: req.IsAutoUpdate,
|
||||||
|
IsCreationAllowed: req.IsCreationAllowed,
|
||||||
|
IsLinkingAllowed: req.IsLinkingAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := connectorManager.AddSAMLConnector(r.Context(), config)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to add SAML connector: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
util.WriteJSONObject(r.Context(), w, toConnectorResponse(connector))
|
||||||
|
}
|
||||||
|
|
||||||
|
// activateConnector adds the connector to the login policy
|
||||||
|
func (h *handler) activateConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
connectorID := vars["connectorId"]
|
||||||
|
if connectorID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := connectorManager.ActivateConnector(r.Context(), connectorID); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to activate connector: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// deactivateConnector removes the connector from the login policy
|
||||||
|
func (h *handler) deactivateConnector(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only admins can manage IdP connectors"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
connectorID := vars["connectorId"]
|
||||||
|
if connectorID == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "connector ID is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorManager, err := h.getConnectorManager()
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := connectorManager.DeactivateConnector(r.Context(), connectorID); err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "failed to deactivate connector: %v", err), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// toConnectorResponse converts an idp.Connector to a ConnectorResponse
|
||||||
|
func toConnectorResponse(c *idp.Connector) *ConnectorResponse {
|
||||||
|
return &ConnectorResponse{
|
||||||
|
ID: c.ID,
|
||||||
|
Name: c.Name,
|
||||||
|
Type: string(c.Type),
|
||||||
|
State: c.State,
|
||||||
|
Issuer: c.Issuer,
|
||||||
|
Servers: c.Servers,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,959 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Auth0Manager auth0 manager client instance
|
|
||||||
type Auth0Manager struct {
|
|
||||||
authIssuer string
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth0ClientConfig auth0 manager client configurations
|
|
||||||
type Auth0ClientConfig struct {
|
|
||||||
Audience string
|
|
||||||
AuthIssuer string
|
|
||||||
ClientID string
|
|
||||||
ClientSecret string
|
|
||||||
GrantType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// auth0JWTRequest payload struct to request a JWT Token
|
|
||||||
type auth0JWTRequest struct {
|
|
||||||
Audience string `json:"audience"`
|
|
||||||
AuthIssuer string `json:"auth_issuer"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
ClientSecret string `json:"client_secret"`
|
|
||||||
GrantType string `json:"grant_type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth0Credentials auth0 authentication information
|
|
||||||
type Auth0Credentials struct {
|
|
||||||
clientConfig Auth0ClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
jwtToken JWTToken
|
|
||||||
mux sync.Mutex
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// createUserRequest is a user create request
|
|
||||||
type createUserRequest struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
AppMeta AppMetadata `json:"app_metadata"`
|
|
||||||
Connection string `json:"connection"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
VerifyEmail bool `json:"verify_email"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// userExportJobRequest is a user export request struct
|
|
||||||
type userExportJobRequest struct {
|
|
||||||
Format string `json:"format"`
|
|
||||||
Fields []map[string]string `json:"fields"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// userExportJobResponse is a user export response struct
|
|
||||||
type userExportJobResponse struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
ConnectionID string `json:"connection_id"`
|
|
||||||
Format string `json:"format"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
Connection string `json:"connection"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// userExportJobStatusResponse is a user export status response struct
|
|
||||||
type userExportJobStatusResponse struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
ConnectionID string `json:"connection_id"`
|
|
||||||
Format string `json:"format"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
Location string `json:"location"`
|
|
||||||
Connection string `json:"connection"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// userVerificationJobRequest is a user verification request struct
|
|
||||||
type userVerificationJobRequest struct {
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// auth0Profile represents an Auth0 user profile response
|
|
||||||
type auth0Profile struct {
|
|
||||||
AccountID string `json:"wt_account_id"`
|
|
||||||
PendingInvite bool `json:"wt_pending_invite"`
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
CreatedAt string `json:"created_at"`
|
|
||||||
LastLogin string `json:"last_login"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connections represents a single Auth0 connection
|
|
||||||
// https://auth0.com/docs/api/management/v2/connections/get-connections
|
|
||||||
type Connection struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
DisplayName string `json:"display_name"`
|
|
||||||
IsDomainConnection bool `json:"is_domain_connection"`
|
|
||||||
Realms []string `json:"realms"`
|
|
||||||
Metadata map[string]string `json:"metadata"`
|
|
||||||
Options ConnectionOptions `json:"options"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConnectionOptions struct {
|
|
||||||
DomainAliases []string `json:"domain_aliases"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
|
||||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.AuthIssuer == "" {
|
|
||||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, AuthIssuer is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientID is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ClientSecret == "" {
|
|
||||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Audience == "" {
|
|
||||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GrantType == "" {
|
|
||||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, GrantType is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
credentials := &Auth0Credentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Auth0Manager{
|
|
||||||
authIssuer: config.AuthIssuer,
|
|
||||||
credentials: credentials,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from Auth0
|
|
||||||
func (c *Auth0Credentials) jwtStillValid() bool {
|
|
||||||
return !c.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(c.jwtToken.expiresInTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestJWTToken performs request to get jwt token
|
|
||||||
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
|
||||||
var res *http.Response
|
|
||||||
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
|
|
||||||
|
|
||||||
p, err := c.helper.Marshal(auth0JWTRequest(c.clientConfig))
|
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
payload := strings.NewReader(string(p))
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", reqURL, payload)
|
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
|
|
||||||
|
|
||||||
res, err = c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if c.appMetrics != nil {
|
|
||||||
c.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return res, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return res, fmt.Errorf("unable to get token, statusCode %d", res.StatusCode)
|
|
||||||
}
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
|
||||||
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
body, err := io.ReadAll(rawBody)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.helper.Unmarshal(body, &jwtToken)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
|
||||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
|
||||||
}
|
|
||||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
// Exp maps into exp from jwt token
|
|
||||||
var IssuedAt struct{ Exp int64 }
|
|
||||||
err = json.Unmarshal(data, &IssuedAt)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
|
||||||
|
|
||||||
return jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate retrieves access token to use the Auth0 Management API
|
|
||||||
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
|
||||||
c.mux.Lock()
|
|
||||||
defer c.mux.Unlock()
|
|
||||||
|
|
||||||
if c.appMetrics != nil {
|
|
||||||
c.appMetrics.IDPMetrics().CountAuthenticate()
|
|
||||||
}
|
|
||||||
|
|
||||||
// If jwtToken has an expires time and we have enough time to do a request return immediately
|
|
||||||
if c.jwtStillValid() {
|
|
||||||
return c.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := c.requestJWTToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return c.jwtToken, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
jwtToken, err := c.parseRequestJWTResponse(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return c.jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.jwtToken = jwtToken
|
|
||||||
|
|
||||||
return c.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
|
|
||||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("page", strconv.Itoa(page))
|
|
||||||
q.Set("search_engine", "v3")
|
|
||||||
q.Set("per_page", strconv.Itoa(perPage))
|
|
||||||
q.Set("q", "app_metadata.wt_account_id:"+accountID)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
return u.String(), q, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func requestByUserIDURL(authIssuer, userID string) string {
|
|
||||||
return authIssuer + "/api/v2/users/" + userID
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile. Calls Auth0 API.
|
|
||||||
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var list []*UserData
|
|
||||||
|
|
||||||
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
|
||||||
// auth0 limitation of 1000 users via this endpoint
|
|
||||||
resultsPerPage := 50
|
|
||||||
for page := 0; page < 20; page++ {
|
|
||||||
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
res, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
var batch []UserData
|
|
||||||
err = json.Unmarshal(body, &batch)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
|
||||||
|
|
||||||
err = res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for user := range batch {
|
|
||||||
list = append(list, &batch[user])
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
|
||||||
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
|
|
||||||
return list, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return list, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from auth0 via ID
|
|
||||||
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := requestByUserIDURL(am.authIssuer, userID)
|
|
||||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
res, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var userData UserData
|
|
||||||
err = json.Unmarshal(body, &userData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("unable to get UserData, statusCode %d", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
|
|
||||||
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
|
||||||
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := am.authIssuer + "/api/v2/users/" + userID
|
|
||||||
|
|
||||||
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := strings.NewReader(string(data))
|
|
||||||
|
|
||||||
req, err := http.NewRequest("PATCH", reqURL, payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
|
|
||||||
|
|
||||||
res, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildCreateUserRequestPayload(email, name, accountID, invitedByEmail string) (string, error) {
|
|
||||||
invite := true
|
|
||||||
req := &createUserRequest{
|
|
||||||
Email: email,
|
|
||||||
Name: name,
|
|
||||||
AppMeta: AppMetadata{
|
|
||||||
WTAccountID: accountID,
|
|
||||||
WTPendingInvite: &invite,
|
|
||||||
WTInvitedBy: invitedByEmail,
|
|
||||||
},
|
|
||||||
Connection: "Username-Password-Authentication",
|
|
||||||
Password: GeneratePassword(8, 1, 1, 1),
|
|
||||||
VerifyEmail: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
str, err := json.Marshal(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(str), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildUserExportRequest() (string, error) {
|
|
||||||
req := &userExportJobRequest{}
|
|
||||||
fields := make([]map[string]string, 0)
|
|
||||||
|
|
||||||
for _, field := range []string{"created_at", "last_login", "user_id", "email", "name"} {
|
|
||||||
fields = append(fields, map[string]string{"name": field})
|
|
||||||
}
|
|
||||||
|
|
||||||
fields = append(fields, map[string]string{
|
|
||||||
"name": "app_metadata.wt_account_id",
|
|
||||||
"export_as": "wt_account_id",
|
|
||||||
})
|
|
||||||
|
|
||||||
fields = append(fields, map[string]string{
|
|
||||||
"name": "app_metadata.wt_pending_invite",
|
|
||||||
"export_as": "wt_pending_invite",
|
|
||||||
})
|
|
||||||
|
|
||||||
req.Format = "json"
|
|
||||||
req.Fields = fields
|
|
||||||
|
|
||||||
str, err := json.Marshal(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(str), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *Auth0Manager) createRequest(
|
|
||||||
ctx context.Context, method string, endpoint string, body io.Reader,
|
|
||||||
) (*http.Request, error) {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := am.authIssuer + endpoint
|
|
||||||
|
|
||||||
req, err := http.NewRequest(method, reqURL, body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
|
||||||
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
payloadString, err := buildUserExportRequest()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = jobResp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if jobResp.StatusCode != 200 {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var exportJobResp userExportJobResponse
|
|
||||||
|
|
||||||
body, err := io.ReadAll(jobResp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.helper.Unmarshal(body, &exportJobResp)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if exportJobResp.ID == "" {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
|
||||||
|
|
||||||
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if done {
|
|
||||||
return am.downloadProfileExport(ctx, downloadLink)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
|
||||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
|
||||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
|
||||||
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
|
|
||||||
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
userResp := []*UserData{}
|
|
||||||
|
|
||||||
err = am.helper.Unmarshal(body, &userResp)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return userResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
|
||||||
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
|
||||||
|
|
||||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var createResp UserData
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.helper.Unmarshal(body, &createResp)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if createResp.ID == "" {
|
|
||||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
|
|
||||||
|
|
||||||
return &createResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
|
|
||||||
userVerificationReq := userVerificationJobRequest{
|
|
||||||
UserID: userID,
|
|
||||||
}
|
|
||||||
|
|
||||||
payload, err := am.helper.Marshal(userVerificationReq)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from Auth0
|
|
||||||
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
|
|
||||||
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("execute delete request: %v", err)
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("close delete request body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if resp.StatusCode != 204 {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllConnections returns detailed list of all connections filtered by given params.
|
|
||||||
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
|
|
||||||
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
|
|
||||||
var connections []Connection
|
|
||||||
|
|
||||||
q := make(url.Values)
|
|
||||||
q.Set("strategy", strings.Join(strategy, ","))
|
|
||||||
|
|
||||||
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return connections, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return connections, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return connections, fmt.Errorf("unable to get connections, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
|
|
||||||
return connections, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = am.helper.Unmarshal(body, &connections)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
|
|
||||||
return connections, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return connections, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
|
||||||
// If the status is "completed", then return the downloadLink
|
|
||||||
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
retry := time.NewTicker(10 * time.Second)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.WithContext(ctx).Debugf("Export job status stopped...\n")
|
|
||||||
return false, "", ctx.Err()
|
|
||||||
case <-retry.C:
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
|
|
||||||
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var status userExportJobStatusResponse
|
|
||||||
err = am.helper.Unmarshal(body, &status)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
|
|
||||||
|
|
||||||
if status.Status != "completed" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, status.Location, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// downloadProfileExport downloads user profiles from auth0 batch job
|
|
||||||
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
|
|
||||||
body, err := doGetReq(ctx, am.httpClient, location, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyReader := bytes.NewReader(body)
|
|
||||||
|
|
||||||
gzipReader, err := gzip.NewReader(bodyReader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder := json.NewDecoder(gzipReader)
|
|
||||||
|
|
||||||
res := make(map[string][]*UserData)
|
|
||||||
for decoder.More() {
|
|
||||||
profile := auth0Profile{}
|
|
||||||
err = decoder.Decode(&profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if profile.AccountID != "" {
|
|
||||||
if _, ok := res[profile.AccountID]; !ok {
|
|
||||||
res[profile.AccountID] = []*UserData{}
|
|
||||||
}
|
|
||||||
res[profile.AccountID] = append(res[profile.AccountID],
|
|
||||||
&UserData{
|
|
||||||
ID: profile.UserID,
|
|
||||||
Name: profile.Name,
|
|
||||||
Email: profile.Email,
|
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: profile.AccountID,
|
|
||||||
WTPendingInvite: &profile.PendingInvite,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Boilerplate implementation for Get Requests.
|
|
||||||
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if accessToken != "" {
|
|
||||||
req.Header.Add("authorization", "Bearer "+accessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = res.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
|
||||||
}
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
@@ -1,473 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
|
||||||
code int
|
|
||||||
resBody string
|
|
||||||
reqBody string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
|
||||||
if req.Body != nil {
|
|
||||||
body, err := io.ReadAll(req.Body)
|
|
||||||
if err == nil {
|
|
||||||
c.reqBody = string(body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: c.code,
|
|
||||||
Body: io.NopCloser(strings.NewReader(c.resBody)),
|
|
||||||
}, c.err
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockJsonParser struct {
|
|
||||||
jsonParser JsonParser
|
|
||||||
marshalErrorString string
|
|
||||||
unmarshalErrorString string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
|
|
||||||
if m.marshalErrorString != "" {
|
|
||||||
return nil, errors.New(m.marshalErrorString)
|
|
||||||
}
|
|
||||||
return m.jsonParser.Marshal(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
|
|
||||||
if m.unmarshalErrorString != "" {
|
|
||||||
return errors.New(m.unmarshalErrorString)
|
|
||||||
}
|
|
||||||
return m.jsonParser.Unmarshal(data, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockAuth0Credentials struct {
|
|
||||||
jwtToken JWTToken
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
|
|
||||||
return mc.jwtToken, mc.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestJWT(t *testing.T, expInt int) string {
|
|
||||||
t.Helper()
|
|
||||||
now := time.Now()
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
||||||
"iat": now.Unix(),
|
|
||||||
"exp": now.Add(time.Duration(expInt) * time.Second).Unix(),
|
|
||||||
})
|
|
||||||
var hmacSampleSecret []byte
|
|
||||||
tokenString, err := token.SignedString(hmacSampleSecret)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return tokenString
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth0_RequestJWTToken(t *testing.T) {
|
|
||||||
|
|
||||||
type requestJWTTokenTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputResBody string
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedCode int
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
|
||||||
name: "Get Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
|
||||||
name: "Request Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputResBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputResBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := Auth0ClientConfig{}
|
|
||||||
|
|
||||||
creds := Auth0Credentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := creds.requestJWTToken(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
assert.NoError(t, err, "unable to read the response body")
|
|
||||||
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
err = json.Unmarshal(body, &jwtToken)
|
|
||||||
assert.NoError(t, err, "unable to parse the json input")
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
|
|
||||||
type parseRequestJWTResponseTest struct {
|
|
||||||
name string
|
|
||||||
inputResBody string
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedToken string
|
|
||||||
expectedExpiresIn int
|
|
||||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
exp := 100
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
|
||||||
name: "Parse Good JWT Body",
|
|
||||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: token,
|
|
||||||
expectedExpiresIn: exp,
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "no error was expected",
|
|
||||||
}
|
|
||||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
|
||||||
name: "Parse Bad json JWT Body",
|
|
||||||
inputResBody: "",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: "",
|
|
||||||
expectedExpiresIn: 0,
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "json error was expected",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
|
|
||||||
|
|
||||||
config := Auth0ClientConfig{}
|
|
||||||
|
|
||||||
creds := Auth0Credentials{
|
|
||||||
clientConfig: config,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth0_JwtStillValid(t *testing.T) {
|
|
||||||
|
|
||||||
type jwtStillValidTest struct {
|
|
||||||
name string
|
|
||||||
inputTime time.Time
|
|
||||||
expectedResult bool
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
|
||||||
name: "JWT still valid",
|
|
||||||
inputTime: time.Now().Add(10 * time.Second),
|
|
||||||
expectedResult: true,
|
|
||||||
message: "should be true",
|
|
||||||
}
|
|
||||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
|
||||||
name: "JWT is invalid",
|
|
||||||
inputTime: time.Now(),
|
|
||||||
expectedResult: false,
|
|
||||||
message: "should be false",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
config := Auth0ClientConfig{}
|
|
||||||
|
|
||||||
creds := Auth0Credentials{
|
|
||||||
clientConfig: config,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth0_Authenticate(t *testing.T) {
|
|
||||||
type authenticateTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputResBody string
|
|
||||||
inputExpireToken time.Time
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedCode int
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
authenticateTestCase1 := authenticateTest{
|
|
||||||
name: "Get Cached token",
|
|
||||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
|
||||||
helper: JsonParser{},
|
|
||||||
// expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase2 := authenticateTest{
|
|
||||||
name: "Get Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
authenticateTestCase3 := authenticateTest{
|
|
||||||
name: "Get Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputResBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputResBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := Auth0ClientConfig{}
|
|
||||||
|
|
||||||
creds := Auth0Credentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
|
||||||
|
|
||||||
_, err := creds.Authenticate(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
|
||||||
|
|
||||||
type updateUserAppMetadataTest struct {
|
|
||||||
name string
|
|
||||||
inputReqBody string
|
|
||||||
expectedReqBody string
|
|
||||||
appMetadata AppMetadata
|
|
||||||
statusCode int
|
|
||||||
helper ManagerHelper
|
|
||||||
managerCreds ManagerCredentials
|
|
||||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
exp := 15
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
|
||||||
name: "Bad Authentication",
|
|
||||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
expectedReqBody: "",
|
|
||||||
appMetadata: appMetadata,
|
|
||||||
statusCode: 400,
|
|
||||||
helper: JsonParser{},
|
|
||||||
managerCreds: &mockAuth0Credentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
err: fmt.Errorf("error"),
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "should return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
|
||||||
name: "Bad Status Code",
|
|
||||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
|
||||||
appMetadata: appMetadata,
|
|
||||||
statusCode: 400,
|
|
||||||
helper: JsonParser{},
|
|
||||||
managerCreds: &mockAuth0Credentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "should return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
|
||||||
name: "Bad Response Parsing",
|
|
||||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
statusCode: 400,
|
|
||||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "should return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
|
||||||
name: "Good request",
|
|
||||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
|
||||||
appMetadata: appMetadata,
|
|
||||||
statusCode: 200,
|
|
||||||
helper: JsonParser{},
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
invite := true
|
|
||||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
|
||||||
name: "Update Pending Invite",
|
|
||||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
|
|
||||||
appMetadata: AppMetadata{
|
|
||||||
WTAccountID: "ok",
|
|
||||||
WTPendingInvite: &invite,
|
|
||||||
},
|
|
||||||
statusCode: 200,
|
|
||||||
helper: JsonParser{},
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
|
||||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputReqBody,
|
|
||||||
code: testCase.statusCode,
|
|
||||||
}
|
|
||||||
config := Auth0ClientConfig{}
|
|
||||||
|
|
||||||
var creds ManagerCredentials
|
|
||||||
if testCase.managerCreds != nil {
|
|
||||||
creds = testCase.managerCreds
|
|
||||||
} else {
|
|
||||||
creds = &Auth0Credentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := Auth0Manager{
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
credentials: creds,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewAuth0Manager(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
name string
|
|
||||||
inputConfig Auth0ClientConfig
|
|
||||||
assertErrFunc require.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultTestConfig := Auth0ClientConfig{
|
|
||||||
AuthIssuer: "https://abc-auth0.eu.auth0.com",
|
|
||||||
Audience: "https://abc-auth0.eu.auth0.com/api/v2/",
|
|
||||||
ClientID: "abcdefg",
|
|
||||||
ClientSecret: "supersecret",
|
|
||||||
GrantType: "client_credentials",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase1 := test{
|
|
||||||
name: "Good Scenario With Config",
|
|
||||||
inputConfig: defaultTestConfig,
|
|
||||||
assertErrFunc: require.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase2Config := defaultTestConfig
|
|
||||||
testCase2Config.ClientID = ""
|
|
||||||
|
|
||||||
testCase2 := test{
|
|
||||||
name: "Missing Configuration",
|
|
||||||
inputConfig: testCase2Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "shouldn't return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase3Config := defaultTestConfig
|
|
||||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,428 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"goauthentik.io/api/v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AuthentikManager authentik manager client instance.
|
|
||||||
type AuthentikManager struct {
|
|
||||||
apiClient *api.APIClient
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthentikClientConfig authentik manager client configurations.
|
|
||||||
type AuthentikClientConfig struct {
|
|
||||||
Issuer string
|
|
||||||
ClientID string
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
TokenEndpoint string
|
|
||||||
GrantType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthentikCredentials authentik authentication information.
|
|
||||||
type AuthentikCredentials struct {
|
|
||||||
clientConfig AuthentikClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
jwtToken JWTToken
|
|
||||||
mux sync.Mutex
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
|
||||||
func NewAuthentikManager(config AuthentikClientConfig,
|
|
||||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, clientID is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Username == "" {
|
|
||||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Username is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Password == "" {
|
|
||||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Password is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Issuer == "" {
|
|
||||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GrantType == "" {
|
|
||||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
// authentik client configuration
|
|
||||||
issuerURL, err := url.Parse(config.Issuer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
authentikConfig := api.NewConfiguration()
|
|
||||||
authentikConfig.HTTPClient = httpClient
|
|
||||||
authentikConfig.Host = issuerURL.Host
|
|
||||||
authentikConfig.Scheme = issuerURL.Scheme
|
|
||||||
|
|
||||||
credentials := &AuthentikCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AuthentikManager{
|
|
||||||
apiClient: api.NewAPIClient(authentikConfig),
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from authentik.
|
|
||||||
func (ac *AuthentikCredentials) jwtStillValid() bool {
|
|
||||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestJWTToken performs request to get jwt token.
|
|
||||||
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("client_id", ac.clientConfig.ClientID)
|
|
||||||
data.Set("username", ac.clientConfig.Username)
|
|
||||||
data.Set("password", ac.clientConfig.Password)
|
|
||||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
|
||||||
data.Set("scope", "goauthentik.io/api")
|
|
||||||
|
|
||||||
payload := strings.NewReader(data.Encode())
|
|
||||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
|
|
||||||
|
|
||||||
resp, err := ac.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if ac.appMetrics != nil {
|
|
||||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("unable to get authentik token, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
|
||||||
func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
body, err := io.ReadAll(rawBody)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
|
||||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exp maps into exp from jwt token
|
|
||||||
var IssuedAt struct{ Exp int64 }
|
|
||||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
|
||||||
|
|
||||||
return jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate retrieves access token to use the authentik management API.
|
|
||||||
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
|
||||||
ac.mux.Lock()
|
|
||||||
defer ac.mux.Unlock()
|
|
||||||
|
|
||||||
if ac.appMetrics != nil {
|
|
||||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
|
||||||
}
|
|
||||||
|
|
||||||
// reuse the token without requesting a new one if it is not expired,
|
|
||||||
// and if expiry time is sufficient time available to make a request.
|
|
||||||
if ac.jwtStillValid() {
|
|
||||||
return ac.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := ac.requestJWTToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return ac.jwtToken, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return ac.jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ac.jwtToken = jwtToken
|
|
||||||
|
|
||||||
return ac.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
|
||||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from authentik via ID.
|
|
||||||
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
ctx, err := am.authenticationContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, resp, err := am.apiClient.CoreApi.CoreUsersRetrieve(ctx, int32(userPk)).Execute()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := parseAuthentikUser(*user)
|
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
|
||||||
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
users, err := am.getAllUsers(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
for index, user := range users {
|
|
||||||
user.AppMetadata.WTAccountID = accountID
|
|
||||||
users[index] = user
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
users, err := am.getAllUsers(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAllUsers returns all users in a Authentik account.
|
|
||||||
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
|
|
||||||
page := int32(1)
|
|
||||||
for {
|
|
||||||
ctx, err := am.authenticationContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range userList.Results {
|
|
||||||
users = append(users, parseAuthentikUser(user))
|
|
||||||
}
|
|
||||||
|
|
||||||
page = int32(userList.GetPagination().Next)
|
|
||||||
if userList.GetPagination().Next == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
|
||||||
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
|
||||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
|
||||||
// If no users have been found, this function returns an empty list.
|
|
||||||
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
|
||||||
ctx, err := am.authenticationContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Email(email).Execute()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, user := range userList.Results {
|
|
||||||
users = append(users, parseAuthentikUser(user))
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
|
|
||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from Authentik
|
|
||||||
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
|
|
||||||
ctx, err := am.authenticationContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close() // nolint
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusNoContent {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
value := map[string]api.APIKey{
|
|
||||||
"authentik": {
|
|
||||||
Key: jwtToken.AccessToken,
|
|
||||||
Prefix: jwtToken.TokenType,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseAuthentikUser(user api.User) *UserData {
|
|
||||||
return &UserData{
|
|
||||||
Email: *user.Email,
|
|
||||||
Name: user.Name,
|
|
||||||
ID: strconv.FormatInt(int64(user.Pk), 10),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,320 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewAuthentikManager(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
name string
|
|
||||||
inputConfig AuthentikClientConfig
|
|
||||||
assertErrFunc require.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultTestConfig := AuthentikClientConfig{
|
|
||||||
ClientID: "client_id",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
TokenEndpoint: "https://localhost:8080/application/o/token/",
|
|
||||||
Issuer: "https://localhost:8080/application/o/netbird/",
|
|
||||||
GrantType: "client_credentials",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase1 := test{
|
|
||||||
name: "Good Configuration",
|
|
||||||
inputConfig: defaultTestConfig,
|
|
||||||
assertErrFunc: require.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase2Config := defaultTestConfig
|
|
||||||
testCase2Config.ClientID = ""
|
|
||||||
|
|
||||||
testCase2 := test{
|
|
||||||
name: "Missing ClientID Configuration",
|
|
||||||
inputConfig: testCase2Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase3Config := defaultTestConfig
|
|
||||||
testCase3Config.Username = ""
|
|
||||||
|
|
||||||
testCase3 := test{
|
|
||||||
name: "Missing Username Configuration",
|
|
||||||
inputConfig: testCase3Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase4Config := defaultTestConfig
|
|
||||||
testCase4Config.Password = ""
|
|
||||||
|
|
||||||
testCase4 := test{
|
|
||||||
name: "Missing Password Configuration",
|
|
||||||
inputConfig: testCase4Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase5Config := defaultTestConfig
|
|
||||||
testCase5Config.GrantType = ""
|
|
||||||
|
|
||||||
testCase5 := test{
|
|
||||||
name: "Missing GrantType Configuration",
|
|
||||||
inputConfig: testCase5Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase6Config := defaultTestConfig
|
|
||||||
testCase6Config.Issuer = ""
|
|
||||||
|
|
||||||
testCase6 := test{
|
|
||||||
name: "Missing Issuer Configuration",
|
|
||||||
inputConfig: testCase6Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
_, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthentikRequestJWTToken(t *testing.T) {
|
|
||||||
type requestJWTTokenTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputRespBody string
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
|
||||||
name: "Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
|
||||||
name: "Request Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputRespBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputRespBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := AuthentikClientConfig{}
|
|
||||||
|
|
||||||
creds := AuthentikCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := creds.requestJWTToken(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
assert.NoError(t, err, "unable to read the response body")
|
|
||||||
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
|
||||||
assert.NoError(t, err, "unable to parse the json input")
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthentikParseRequestJWTResponse(t *testing.T) {
|
|
||||||
type parseRequestJWTResponseTest struct {
|
|
||||||
name string
|
|
||||||
inputRespBody string
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedToken string
|
|
||||||
expectedExpiresIn int
|
|
||||||
assertErrFunc assert.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
exp := 100
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
|
||||||
name: "Parse Good JWT Body",
|
|
||||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: token,
|
|
||||||
expectedExpiresIn: exp,
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "no error was expected",
|
|
||||||
}
|
|
||||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
|
||||||
name: "Parse Bad json JWT Body",
|
|
||||||
inputRespBody: "",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: "",
|
|
||||||
expectedExpiresIn: 0,
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "json error was expected",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
|
||||||
config := AuthentikClientConfig{}
|
|
||||||
|
|
||||||
creds := AuthentikCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthentikJwtStillValid(t *testing.T) {
|
|
||||||
type jwtStillValidTest struct {
|
|
||||||
name string
|
|
||||||
inputTime time.Time
|
|
||||||
expectedResult bool
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
|
||||||
name: "JWT still valid",
|
|
||||||
inputTime: time.Now().Add(10 * time.Second),
|
|
||||||
expectedResult: true,
|
|
||||||
message: "should be true",
|
|
||||||
}
|
|
||||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
|
||||||
name: "JWT is invalid",
|
|
||||||
inputTime: time.Now(),
|
|
||||||
expectedResult: false,
|
|
||||||
message: "should be false",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
config := AuthentikClientConfig{}
|
|
||||||
|
|
||||||
creds := AuthentikCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthentikAuthenticate(t *testing.T) {
|
|
||||||
type authenticateTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputResBody string
|
|
||||||
inputExpireToken time.Time
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedCode int
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
authenticateTestCase1 := authenticateTest{
|
|
||||||
name: "Get Cached token",
|
|
||||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: nil,
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase2 := authenticateTest{
|
|
||||||
name: "Get Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase3 := authenticateTest{
|
|
||||||
name: "Get Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputResBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputResBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := AuthentikClientConfig{}
|
|
||||||
|
|
||||||
creds := AuthentikCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
|
||||||
|
|
||||||
_, err := creds.Authenticate(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,459 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
const profileFields = "id,displayName,mail,userPrincipalName"
|
|
||||||
|
|
||||||
// AzureManager azure manager client instance.
|
|
||||||
type AzureManager struct {
|
|
||||||
ClientID string
|
|
||||||
ObjectID string
|
|
||||||
GraphAPIEndpoint string
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// AzureClientConfig azure manager client configurations.
|
|
||||||
type AzureClientConfig struct {
|
|
||||||
ClientID string
|
|
||||||
ClientSecret string
|
|
||||||
ObjectID string
|
|
||||||
GraphAPIEndpoint string
|
|
||||||
TokenEndpoint string
|
|
||||||
GrantType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// AzureCredentials azure authentication information.
|
|
||||||
type AzureCredentials struct {
|
|
||||||
clientConfig AzureClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
jwtToken JWTToken
|
|
||||||
mux sync.Mutex
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// azureProfile represents an azure user profile.
|
|
||||||
type azureProfile map[string]any
|
|
||||||
|
|
||||||
// NewAzureManager creates a new instance of the AzureManager.
|
|
||||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ClientSecret == "" {
|
|
||||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GraphAPIEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ObjectID == "" {
|
|
||||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GrantType == "" {
|
|
||||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
credentials := &AzureCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AzureManager{
|
|
||||||
ObjectID: config.ObjectID,
|
|
||||||
ClientID: config.ClientID,
|
|
||||||
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
|
||||||
func (ac *AzureCredentials) jwtStillValid() bool {
|
|
||||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestJWTToken performs request to get jwt token.
|
|
||||||
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("client_id", ac.clientConfig.ClientID)
|
|
||||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
|
||||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
|
||||||
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// get base url and add "/.default" as scope
|
|
||||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
|
||||||
scopeURL := baseURL + "/.default"
|
|
||||||
data.Set("scope", scopeURL)
|
|
||||||
|
|
||||||
payload := strings.NewReader(data.Encode())
|
|
||||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
|
|
||||||
|
|
||||||
resp, err := ac.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if ac.appMetrics != nil {
|
|
||||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
|
||||||
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
body, err := io.ReadAll(rawBody)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
|
||||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exp maps into exp from jwt token
|
|
||||||
var IssuedAt struct{ Exp int64 }
|
|
||||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
|
||||||
|
|
||||||
return jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate retrieves access token to use the azure Management API.
|
|
||||||
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
|
||||||
ac.mux.Lock()
|
|
||||||
defer ac.mux.Unlock()
|
|
||||||
|
|
||||||
if ac.appMetrics != nil {
|
|
||||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
|
||||||
}
|
|
||||||
|
|
||||||
// reuse the token without requesting a new one if it is not expired,
|
|
||||||
// and if expiry time is sufficient time available to make a request.
|
|
||||||
if ac.jwtStillValid() {
|
|
||||||
return ac.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := ac.requestJWTToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return ac.jwtToken, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return ac.jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ac.jwtToken = jwtToken
|
|
||||||
|
|
||||||
return ac.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in azure AD Idp.
|
|
||||||
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
|
||||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from keycloak via ID.
|
|
||||||
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("$select", profileFields)
|
|
||||||
|
|
||||||
body, err := am.get(ctx, "users/"+userID, q)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
var profile azureProfile
|
|
||||||
err = am.helper.Unmarshal(body, &profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := profile.userData()
|
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
|
||||||
// If no users have been found, this function returns an empty list.
|
|
||||||
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("$select", profileFields)
|
|
||||||
|
|
||||||
body, err := am.get(ctx, "users/"+email, q)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
var profile azureProfile
|
|
||||||
err = am.helper.Unmarshal(body, &profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
users = append(users, profile.userData())
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
|
||||||
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
users, err := am.getAllUsers(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
for index, user := range users {
|
|
||||||
user.AppMetadata.WTAccountID = accountID
|
|
||||||
users[index] = user
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
users, err := am.getAllUsers(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
|
||||||
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
|
|
||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from Azure.
|
|
||||||
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
|
||||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("delete idp user %s", userID)
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusNoContent {
|
|
||||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAllUsers returns all users in an Azure AD account.
|
|
||||||
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("$select", profileFields)
|
|
||||||
q.Add("$top", "500")
|
|
||||||
|
|
||||||
for nextLink := "users"; nextLink != ""; {
|
|
||||||
body, err := am.get(ctx, nextLink, q)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var profiles struct {
|
|
||||||
Value []azureProfile
|
|
||||||
NextLink string `json:"@odata.nextLink"`
|
|
||||||
}
|
|
||||||
err = am.helper.Unmarshal(body, &profiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, profile := range profiles.Value {
|
|
||||||
users = append(users, profile.userData())
|
|
||||||
}
|
|
||||||
|
|
||||||
nextLink = profiles.NextLink
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get perform Get requests.
|
|
||||||
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
|
||||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var reqURL string
|
|
||||||
if strings.HasPrefix(resource, "https") {
|
|
||||||
// Already an absolute URL for paging
|
|
||||||
reqURL = resource
|
|
||||||
} else {
|
|
||||||
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// userData construct user data from keycloak profile.
|
|
||||||
func (ap azureProfile) userData() *UserData {
|
|
||||||
id, ok := ap["id"].(string)
|
|
||||||
if !ok {
|
|
||||||
id = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
email, ok := ap["userPrincipalName"].(string)
|
|
||||||
if !ok {
|
|
||||||
email = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
name, ok := ap["displayName"].(string)
|
|
||||||
if !ok {
|
|
||||||
name = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UserData{
|
|
||||||
Email: email,
|
|
||||||
Name: name,
|
|
||||||
ID: id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAzureJwtStillValid(t *testing.T) {
|
|
||||||
type jwtStillValidTest struct {
|
|
||||||
name string
|
|
||||||
inputTime time.Time
|
|
||||||
expectedResult bool
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
|
||||||
name: "JWT still valid",
|
|
||||||
inputTime: time.Now().Add(10 * time.Second),
|
|
||||||
expectedResult: true,
|
|
||||||
message: "should be true",
|
|
||||||
}
|
|
||||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
|
||||||
name: "JWT is invalid",
|
|
||||||
inputTime: time.Now(),
|
|
||||||
expectedResult: false,
|
|
||||||
message: "should be false",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
config := AzureClientConfig{}
|
|
||||||
|
|
||||||
creds := AzureCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAzureAuthenticate(t *testing.T) {
|
|
||||||
type authenticateTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputResBody string
|
|
||||||
inputExpireToken time.Time
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedCode int
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
authenticateTestCase1 := authenticateTest{
|
|
||||||
name: "Get Cached token",
|
|
||||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: nil,
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase2 := authenticateTest{
|
|
||||||
name: "Get Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase3 := authenticateTest{
|
|
||||||
name: "Get Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputResBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputResBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := AzureClientConfig{}
|
|
||||||
|
|
||||||
creds := AzureCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
|
||||||
|
|
||||||
_, err := creds.Authenticate(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAzureProfile(t *testing.T) {
|
|
||||||
type azureProfileTest struct {
|
|
||||||
name string
|
|
||||||
invite bool
|
|
||||||
inputProfile azureProfile
|
|
||||||
expectedUserData UserData
|
|
||||||
}
|
|
||||||
|
|
||||||
azureProfileTestCase1 := azureProfileTest{
|
|
||||||
name: "Good Request",
|
|
||||||
invite: false,
|
|
||||||
inputProfile: azureProfile{
|
|
||||||
"id": "test1",
|
|
||||||
"displayName": "John Doe",
|
|
||||||
"userPrincipalName": "test1@test.com",
|
|
||||||
},
|
|
||||||
expectedUserData: UserData{
|
|
||||||
Email: "test1@test.com",
|
|
||||||
Name: "John Doe",
|
|
||||||
ID: "test1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
azureProfileTestCase2 := azureProfileTest{
|
|
||||||
name: "Missing User ID",
|
|
||||||
invite: true,
|
|
||||||
inputProfile: azureProfile{
|
|
||||||
"displayName": "John Doe",
|
|
||||||
"userPrincipalName": "test2@test.com",
|
|
||||||
},
|
|
||||||
expectedUserData: UserData{
|
|
||||||
Email: "test2@test.com",
|
|
||||||
Name: "John Doe",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
azureProfileTestCase3 := azureProfileTest{
|
|
||||||
name: "Missing User Name",
|
|
||||||
invite: false,
|
|
||||||
inputProfile: azureProfile{
|
|
||||||
"id": "test3",
|
|
||||||
"userPrincipalName": "test3@test.com",
|
|
||||||
},
|
|
||||||
expectedUserData: UserData{
|
|
||||||
ID: "test3",
|
|
||||||
Email: "test3@test.com",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
|
||||||
userData := testCase.inputProfile.userData()
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
|
||||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
|
||||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
457
management/server/idp/connector.go
Normal file
457
management/server/idp/connector.go
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectorType represents the type of external identity provider connector
|
||||||
|
type ConnectorType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ConnectorTypeOIDC ConnectorType = "oidc"
|
||||||
|
ConnectorTypeLDAP ConnectorType = "ldap"
|
||||||
|
ConnectorTypeSAML ConnectorType = "saml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connector represents an external identity provider configured in Zitadel
|
||||||
|
type Connector struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type ConnectorType `json:"type"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Issuer string `json:"issuer,omitempty"` // for OIDC
|
||||||
|
Servers []string `json:"servers,omitempty"` // for LDAP
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCConnectorConfig contains configuration for adding an OIDC connector
|
||||||
|
type OIDCConnectorConfig struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
IsIDTokenMapping bool `json:"isIdTokenMapping,omitempty"`
|
||||||
|
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
|
||||||
|
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
|
||||||
|
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
|
||||||
|
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
|
||||||
|
IsAutoAccountLinking bool `json:"isAutoAccountLinking,omitempty"`
|
||||||
|
AccountLinkingEnabled bool `json:"accountLinkingEnabled,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LDAPConnectorConfig contains configuration for adding an LDAP connector
|
||||||
|
type LDAPConnectorConfig struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Servers []string `json:"servers"` // e.g., ["ldap://localhost:389"]
|
||||||
|
StartTLS bool `json:"startTls,omitempty"`
|
||||||
|
BaseDN string `json:"baseDn"`
|
||||||
|
BindDN string `json:"bindDn"`
|
||||||
|
BindPassword string `json:"bindPassword"`
|
||||||
|
UserBase string `json:"userBase,omitempty"` // typically "dn"
|
||||||
|
UserObjectClass []string `json:"userObjectClass,omitempty"` // e.g., ["user", "person"]
|
||||||
|
UserFilters []string `json:"userFilters,omitempty"` // e.g., ["uid", "email"]
|
||||||
|
Timeout string `json:"timeout,omitempty"` // e.g., "10s"
|
||||||
|
Attributes LDAPAttributes `json:"attributes,omitempty"`
|
||||||
|
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
|
||||||
|
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
|
||||||
|
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
|
||||||
|
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LDAPAttributes maps LDAP attributes to Zitadel user fields
|
||||||
|
type LDAPAttributes struct {
|
||||||
|
IDAttribute string `json:"idAttribute,omitempty"`
|
||||||
|
FirstNameAttribute string `json:"firstNameAttribute,omitempty"`
|
||||||
|
LastNameAttribute string `json:"lastNameAttribute,omitempty"`
|
||||||
|
DisplayNameAttribute string `json:"displayNameAttribute,omitempty"`
|
||||||
|
NickNameAttribute string `json:"nickNameAttribute,omitempty"`
|
||||||
|
EmailAttribute string `json:"emailAttribute,omitempty"`
|
||||||
|
EmailVerified string `json:"emailVerified,omitempty"`
|
||||||
|
PhoneAttribute string `json:"phoneAttribute,omitempty"`
|
||||||
|
PhoneVerified string `json:"phoneVerified,omitempty"`
|
||||||
|
AvatarURLAttribute string `json:"avatarUrlAttribute,omitempty"`
|
||||||
|
ProfileAttribute string `json:"profileAttribute,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAMLConnectorConfig contains configuration for adding a SAML connector
|
||||||
|
type SAMLConnectorConfig struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
MetadataXML string `json:"metadataXml,omitempty"`
|
||||||
|
MetadataURL string `json:"metadataUrl,omitempty"`
|
||||||
|
Binding string `json:"binding,omitempty"` // "SAML_BINDING_POST" or "SAML_BINDING_REDIRECT"
|
||||||
|
WithSignedRequest bool `json:"withSignedRequest,omitempty"`
|
||||||
|
NameIDFormat string `json:"nameIdFormat,omitempty"`
|
||||||
|
IsAutoCreation bool `json:"isAutoCreation,omitempty"`
|
||||||
|
IsAutoUpdate bool `json:"isAutoUpdate,omitempty"`
|
||||||
|
IsCreationAllowed bool `json:"isCreationAllowed,omitempty"`
|
||||||
|
IsLinkingAllowed bool `json:"isLinkingAllowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectorManager defines the interface for managing external IdP connectors
|
||||||
|
type ConnectorManager interface {
|
||||||
|
// AddOIDCConnector adds a Generic OIDC identity provider connector
|
||||||
|
AddOIDCConnector(ctx context.Context, config OIDCConnectorConfig) (*Connector, error)
|
||||||
|
// AddLDAPConnector adds an LDAP identity provider connector
|
||||||
|
AddLDAPConnector(ctx context.Context, config LDAPConnectorConfig) (*Connector, error)
|
||||||
|
// AddSAMLConnector adds a SAML identity provider connector
|
||||||
|
AddSAMLConnector(ctx context.Context, config SAMLConnectorConfig) (*Connector, error)
|
||||||
|
// ListConnectors returns all configured identity provider connectors
|
||||||
|
ListConnectors(ctx context.Context) ([]*Connector, error)
|
||||||
|
// GetConnector returns a specific connector by ID
|
||||||
|
GetConnector(ctx context.Context, connectorID string) (*Connector, error)
|
||||||
|
// DeleteConnector removes an identity provider connector
|
||||||
|
DeleteConnector(ctx context.Context, connectorID string) error
|
||||||
|
// ActivateConnector adds the connector to the login policy
|
||||||
|
ActivateConnector(ctx context.Context, connectorID string) error
|
||||||
|
// DeactivateConnector removes the connector from the login policy
|
||||||
|
DeactivateConnector(ctx context.Context, connectorID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelProviderResponse represents the response from creating a provider
|
||||||
|
type zitadelProviderResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Details struct {
|
||||||
|
Sequence string `json:"sequence"`
|
||||||
|
CreationDate string `json:"creationDate"`
|
||||||
|
ChangeDate string `json:"changeDate"`
|
||||||
|
ResourceOwner string `json:"resourceOwner"`
|
||||||
|
} `json:"details"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// zitadelProviderTemplate represents a provider in the list response
|
||||||
|
type zitadelProviderTemplate struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
State string `json:"state"` // IDP_STATE_ACTIVE, IDP_STATE_INACTIVE
|
||||||
|
Type string `json:"type"` // IDP_TYPE_OIDC, IDP_TYPE_LDAP, IDP_TYPE_SAML, etc.
|
||||||
|
Owner string `json:"owner"` // IDP_OWNER_TYPE_ORG, IDP_OWNER_TYPE_SYSTEM
|
||||||
|
// Type-specific fields
|
||||||
|
OIDC *struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
} `json:"oidc,omitempty"`
|
||||||
|
LDAP *struct {
|
||||||
|
Servers []string `json:"servers"`
|
||||||
|
BaseDN string `json:"baseDn"`
|
||||||
|
} `json:"ldap,omitempty"`
|
||||||
|
SAML *struct {
|
||||||
|
MetadataURL string `json:"metadataUrl"`
|
||||||
|
} `json:"saml,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOIDCConnector adds a Generic OIDC identity provider connector to Zitadel
|
||||||
|
func (zm *ZitadelManager) AddOIDCConnector(ctx context.Context, config OIDCConnectorConfig) (*Connector, error) {
|
||||||
|
// Set defaults for creation/linking if not specified
|
||||||
|
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
|
||||||
|
config.IsCreationAllowed = true
|
||||||
|
config.IsLinkingAllowed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"name": config.Name,
|
||||||
|
"issuer": config.Issuer,
|
||||||
|
"clientId": config.ClientID,
|
||||||
|
"clientSecret": config.ClientSecret,
|
||||||
|
"isIdTokenMapping": config.IsIDTokenMapping,
|
||||||
|
"isAutoCreation": config.IsAutoCreation,
|
||||||
|
"isAutoUpdate": config.IsAutoUpdate,
|
||||||
|
"isCreationAllowed": config.IsCreationAllowed,
|
||||||
|
"isLinkingAllowed": config.IsLinkingAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Scopes) > 0 {
|
||||||
|
payload["scopes"] = config.Scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.helper.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal OIDC connector config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := zm.post(ctx, "idps/generic_oidc", string(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("add OIDC connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp zitadelProviderResponse
|
||||||
|
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal OIDC connector response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Connector{
|
||||||
|
ID: resp.ID,
|
||||||
|
Name: config.Name,
|
||||||
|
Type: ConnectorTypeOIDC,
|
||||||
|
State: "active",
|
||||||
|
Issuer: config.Issuer,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddLDAPConnector adds an LDAP identity provider connector to Zitadel
|
||||||
|
func (zm *ZitadelManager) AddLDAPConnector(ctx context.Context, config LDAPConnectorConfig) (*Connector, error) {
|
||||||
|
// Set defaults
|
||||||
|
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
|
||||||
|
config.IsCreationAllowed = true
|
||||||
|
config.IsLinkingAllowed = true
|
||||||
|
}
|
||||||
|
if config.UserBase == "" {
|
||||||
|
config.UserBase = "dn"
|
||||||
|
}
|
||||||
|
if config.Timeout == "" {
|
||||||
|
config.Timeout = "10s"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"name": config.Name,
|
||||||
|
"servers": config.Servers,
|
||||||
|
"startTls": config.StartTLS,
|
||||||
|
"baseDn": config.BaseDN,
|
||||||
|
"bindDn": config.BindDN,
|
||||||
|
"bindPassword": config.BindPassword,
|
||||||
|
"userBase": config.UserBase,
|
||||||
|
"timeout": config.Timeout,
|
||||||
|
"isAutoCreation": config.IsAutoCreation,
|
||||||
|
"isAutoUpdate": config.IsAutoUpdate,
|
||||||
|
"isCreationAllowed": config.IsCreationAllowed,
|
||||||
|
"isLinkingAllowed": config.IsLinkingAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.UserObjectClass) > 0 {
|
||||||
|
payload["userObjectClasses"] = config.UserObjectClass
|
||||||
|
}
|
||||||
|
if len(config.UserFilters) > 0 {
|
||||||
|
payload["userFilters"] = config.UserFilters
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add attribute mappings if provided
|
||||||
|
attrs := make(map[string]string)
|
||||||
|
if config.Attributes.IDAttribute != "" {
|
||||||
|
attrs["idAttribute"] = config.Attributes.IDAttribute
|
||||||
|
}
|
||||||
|
if config.Attributes.FirstNameAttribute != "" {
|
||||||
|
attrs["firstNameAttribute"] = config.Attributes.FirstNameAttribute
|
||||||
|
}
|
||||||
|
if config.Attributes.LastNameAttribute != "" {
|
||||||
|
attrs["lastNameAttribute"] = config.Attributes.LastNameAttribute
|
||||||
|
}
|
||||||
|
if config.Attributes.DisplayNameAttribute != "" {
|
||||||
|
attrs["displayNameAttribute"] = config.Attributes.DisplayNameAttribute
|
||||||
|
}
|
||||||
|
if config.Attributes.EmailAttribute != "" {
|
||||||
|
attrs["emailAttribute"] = config.Attributes.EmailAttribute
|
||||||
|
}
|
||||||
|
if len(attrs) > 0 {
|
||||||
|
payload["attributes"] = attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.helper.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal LDAP connector config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := zm.post(ctx, "idps/ldap", string(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("add LDAP connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp zitadelProviderResponse
|
||||||
|
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal LDAP connector response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Connector{
|
||||||
|
ID: resp.ID,
|
||||||
|
Name: config.Name,
|
||||||
|
Type: ConnectorTypeLDAP,
|
||||||
|
State: "active",
|
||||||
|
Servers: config.Servers,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSAMLConnector adds a SAML identity provider connector to Zitadel
|
||||||
|
func (zm *ZitadelManager) AddSAMLConnector(ctx context.Context, config SAMLConnectorConfig) (*Connector, error) {
|
||||||
|
// Set defaults
|
||||||
|
if !config.IsCreationAllowed && !config.IsLinkingAllowed {
|
||||||
|
config.IsCreationAllowed = true
|
||||||
|
config.IsLinkingAllowed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"name": config.Name,
|
||||||
|
"isAutoCreation": config.IsAutoCreation,
|
||||||
|
"isAutoUpdate": config.IsAutoUpdate,
|
||||||
|
"isCreationAllowed": config.IsCreationAllowed,
|
||||||
|
"isLinkingAllowed": config.IsLinkingAllowed,
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.MetadataXML != "" {
|
||||||
|
payload["metadataXml"] = config.MetadataXML
|
||||||
|
} else if config.MetadataURL != "" {
|
||||||
|
payload["metadataUrl"] = config.MetadataURL
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("either metadataXml or metadataUrl must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Binding != "" {
|
||||||
|
payload["binding"] = config.Binding
|
||||||
|
}
|
||||||
|
if config.WithSignedRequest {
|
||||||
|
payload["withSignedRequest"] = config.WithSignedRequest
|
||||||
|
}
|
||||||
|
if config.NameIDFormat != "" {
|
||||||
|
payload["nameIdFormat"] = config.NameIDFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.helper.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal SAML connector config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := zm.post(ctx, "idps/saml", string(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("add SAML connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp zitadelProviderResponse
|
||||||
|
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal SAML connector response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Connector{
|
||||||
|
ID: resp.ID,
|
||||||
|
Name: config.Name,
|
||||||
|
Type: ConnectorTypeSAML,
|
||||||
|
State: "active",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListConnectors returns all configured identity provider connectors
|
||||||
|
func (zm *ZitadelManager) ListConnectors(ctx context.Context) ([]*Connector, error) {
|
||||||
|
// Use the search endpoint to list all providers
|
||||||
|
respBody, err := zm.post(ctx, "idps/_search", "{}")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list connectors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Result []zitadelProviderTemplate `json:"result"`
|
||||||
|
}
|
||||||
|
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal connectors response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connectors := make([]*Connector, 0, len(resp.Result))
|
||||||
|
for _, p := range resp.Result {
|
||||||
|
connector := &Connector{
|
||||||
|
ID: p.ID,
|
||||||
|
Name: p.Name,
|
||||||
|
State: normalizeState(p.State),
|
||||||
|
Type: normalizeType(p.Type),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add type-specific fields
|
||||||
|
if p.OIDC != nil {
|
||||||
|
connector.Issuer = p.OIDC.Issuer
|
||||||
|
}
|
||||||
|
if p.LDAP != nil {
|
||||||
|
connector.Servers = p.LDAP.Servers
|
||||||
|
}
|
||||||
|
|
||||||
|
connectors = append(connectors, connector)
|
||||||
|
}
|
||||||
|
|
||||||
|
return connectors, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnector returns a specific connector by ID
|
||||||
|
func (zm *ZitadelManager) GetConnector(ctx context.Context, connectorID string) (*Connector, error) {
|
||||||
|
respBody, err := zm.get(ctx, fmt.Sprintf("idps/%s", connectorID), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
IDP zitadelProviderTemplate `json:"idp"`
|
||||||
|
}
|
||||||
|
if err := zm.helper.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal connector response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
connector := &Connector{
|
||||||
|
ID: resp.IDP.ID,
|
||||||
|
Name: resp.IDP.Name,
|
||||||
|
State: normalizeState(resp.IDP.State),
|
||||||
|
Type: normalizeType(resp.IDP.Type),
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.IDP.OIDC != nil {
|
||||||
|
connector.Issuer = resp.IDP.OIDC.Issuer
|
||||||
|
}
|
||||||
|
if resp.IDP.LDAP != nil {
|
||||||
|
connector.Servers = resp.IDP.LDAP.Servers
|
||||||
|
}
|
||||||
|
|
||||||
|
return connector, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConnector removes an identity provider connector
|
||||||
|
func (zm *ZitadelManager) DeleteConnector(ctx context.Context, connectorID string) error {
|
||||||
|
if err := zm.delete(ctx, fmt.Sprintf("idps/%s", connectorID)); err != nil {
|
||||||
|
return fmt.Errorf("delete connector: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActivateConnector adds the connector to the organization's login policy
|
||||||
|
func (zm *ZitadelManager) ActivateConnector(ctx context.Context, connectorID string) error {
|
||||||
|
payload := map[string]string{
|
||||||
|
"idpId": connectorID,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := zm.helper.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal activate request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = zm.post(ctx, "policies/login/idps", string(body))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("activate connector: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeactivateConnector removes the connector from the organization's login policy
|
||||||
|
func (zm *ZitadelManager) DeactivateConnector(ctx context.Context, connectorID string) error {
|
||||||
|
if err := zm.delete(ctx, fmt.Sprintf("policies/login/idps/%s", connectorID)); err != nil {
|
||||||
|
return fmt.Errorf("deactivate connector: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeState converts Zitadel state to a simple string
|
||||||
|
func normalizeState(state string) string {
|
||||||
|
switch state {
|
||||||
|
case "IDP_STATE_ACTIVE":
|
||||||
|
return "active"
|
||||||
|
case "IDP_STATE_INACTIVE":
|
||||||
|
return "inactive"
|
||||||
|
default:
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeType converts Zitadel type to ConnectorType
|
||||||
|
func normalizeType(idpType string) ConnectorType {
|
||||||
|
switch idpType {
|
||||||
|
case "IDP_TYPE_OIDC", "IDP_TYPE_OIDC_GENERIC":
|
||||||
|
return ConnectorTypeOIDC
|
||||||
|
case "IDP_TYPE_LDAP":
|
||||||
|
return ConnectorTypeLDAP
|
||||||
|
case "IDP_TYPE_SAML":
|
||||||
|
return ConnectorTypeSAML
|
||||||
|
default:
|
||||||
|
return ConnectorType(idpType)
|
||||||
|
}
|
||||||
|
}
|
||||||
313
management/server/idp/connector_test.go
Normal file
313
management/server/idp/connector_test.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestZitadelManager_AddOIDCConnector(t *testing.T) {
|
||||||
|
// Create a mock response for the OIDC connector creation
|
||||||
|
mockResponse := `{"id": "oidc-123", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
|
||||||
|
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusCreated,
|
||||||
|
resBody: mockResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := OIDCConnectorConfig{
|
||||||
|
Name: "Okta",
|
||||||
|
Issuer: "https://okta.example.com",
|
||||||
|
ClientID: "client-123",
|
||||||
|
ClientSecret: "secret-456",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := manager.AddOIDCConnector(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "oidc-123", connector.ID)
|
||||||
|
assert.Equal(t, "Okta", connector.Name)
|
||||||
|
assert.Equal(t, ConnectorTypeOIDC, connector.Type)
|
||||||
|
assert.Equal(t, "https://okta.example.com", connector.Issuer)
|
||||||
|
|
||||||
|
// Verify the request body contains expected fields
|
||||||
|
var reqBody map[string]any
|
||||||
|
err = json.Unmarshal([]byte(mockClient.reqBody), &reqBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Okta", reqBody["name"])
|
||||||
|
assert.Equal(t, "https://okta.example.com", reqBody["issuer"])
|
||||||
|
assert.Equal(t, "client-123", reqBody["clientId"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_AddLDAPConnector(t *testing.T) {
|
||||||
|
mockResponse := `{"id": "ldap-456", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
|
||||||
|
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusCreated,
|
||||||
|
resBody: mockResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := LDAPConnectorConfig{
|
||||||
|
Name: "Corporate LDAP",
|
||||||
|
Servers: []string{"ldap://ldap.example.com:389"},
|
||||||
|
BaseDN: "dc=example,dc=com",
|
||||||
|
BindDN: "cn=admin,dc=example,dc=com",
|
||||||
|
BindPassword: "admin-password",
|
||||||
|
Attributes: LDAPAttributes{
|
||||||
|
IDAttribute: "uid",
|
||||||
|
EmailAttribute: "mail",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := manager.AddLDAPConnector(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "ldap-456", connector.ID)
|
||||||
|
assert.Equal(t, "Corporate LDAP", connector.Name)
|
||||||
|
assert.Equal(t, ConnectorTypeLDAP, connector.Type)
|
||||||
|
assert.Equal(t, []string{"ldap://ldap.example.com:389"}, connector.Servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_AddSAMLConnector(t *testing.T) {
|
||||||
|
mockResponse := `{"id": "saml-789", "details": {"sequence": "1", "creationDate": "2024-01-01T00:00:00Z", "changeDate": "2024-01-01T00:00:00Z", "resourceOwner": "org-1"}}`
|
||||||
|
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusCreated,
|
||||||
|
resBody: mockResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := SAMLConnectorConfig{
|
||||||
|
Name: "Enterprise SAML",
|
||||||
|
MetadataURL: "https://idp.example.com/metadata.xml",
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := manager.AddSAMLConnector(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "saml-789", connector.ID)
|
||||||
|
assert.Equal(t, "Enterprise SAML", connector.Name)
|
||||||
|
assert.Equal(t, ConnectorTypeSAML, connector.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_AddSAMLConnector_RequiresMetadata(t *testing.T) {
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: &mockHTTPClient{},
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
config := SAMLConnectorConfig{
|
||||||
|
Name: "Invalid SAML",
|
||||||
|
// Neither MetadataXML nor MetadataURL provided
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := manager.AddSAMLConnector(context.Background(), config)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "metadataXml or metadataUrl must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_ListConnectors(t *testing.T) {
|
||||||
|
mockResponse := `{
|
||||||
|
"result": [
|
||||||
|
{
|
||||||
|
"id": "oidc-1",
|
||||||
|
"name": "Google",
|
||||||
|
"state": "IDP_STATE_ACTIVE",
|
||||||
|
"type": "IDP_TYPE_OIDC",
|
||||||
|
"oidc": {"issuer": "https://accounts.google.com", "clientId": "google-client"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "ldap-1",
|
||||||
|
"name": "AD",
|
||||||
|
"state": "IDP_STATE_INACTIVE",
|
||||||
|
"type": "IDP_TYPE_LDAP",
|
||||||
|
"ldap": {"servers": ["ldap://ad.example.com:389"], "baseDn": "dc=example,dc=com"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusOK,
|
||||||
|
resBody: mockResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
connectors, err := manager.ListConnectors(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, connectors, 2)
|
||||||
|
|
||||||
|
assert.Equal(t, "oidc-1", connectors[0].ID)
|
||||||
|
assert.Equal(t, "Google", connectors[0].Name)
|
||||||
|
assert.Equal(t, "active", connectors[0].State)
|
||||||
|
assert.Equal(t, ConnectorTypeOIDC, connectors[0].Type)
|
||||||
|
assert.Equal(t, "https://accounts.google.com", connectors[0].Issuer)
|
||||||
|
|
||||||
|
assert.Equal(t, "ldap-1", connectors[1].ID)
|
||||||
|
assert.Equal(t, "AD", connectors[1].Name)
|
||||||
|
assert.Equal(t, "inactive", connectors[1].State)
|
||||||
|
assert.Equal(t, ConnectorTypeLDAP, connectors[1].Type)
|
||||||
|
assert.Equal(t, []string{"ldap://ad.example.com:389"}, connectors[1].Servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_GetConnector(t *testing.T) {
|
||||||
|
mockResponse := `{
|
||||||
|
"idp": {
|
||||||
|
"id": "oidc-123",
|
||||||
|
"name": "Okta",
|
||||||
|
"state": "IDP_STATE_ACTIVE",
|
||||||
|
"type": "IDP_TYPE_OIDC",
|
||||||
|
"oidc": {"issuer": "https://okta.example.com", "clientId": "client-123"}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusOK,
|
||||||
|
resBody: mockResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
connector, err := manager.GetConnector(context.Background(), "oidc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "oidc-123", connector.ID)
|
||||||
|
assert.Equal(t, "Okta", connector.Name)
|
||||||
|
assert.Equal(t, ConnectorTypeOIDC, connector.Type)
|
||||||
|
assert.Equal(t, "https://okta.example.com", connector.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_DeleteConnector(t *testing.T) {
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusOK,
|
||||||
|
resBody: "{}",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.DeleteConnector(context.Background(), "oidc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_ActivateConnector(t *testing.T) {
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusOK,
|
||||||
|
resBody: "{}",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.ActivateConnector(context.Background(), "oidc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the request body
|
||||||
|
var reqBody map[string]string
|
||||||
|
err = json.Unmarshal([]byte(mockClient.reqBody), &reqBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "oidc-123", reqBody["idpId"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZitadelManager_DeactivateConnector(t *testing.T) {
|
||||||
|
mockClient := &mockHTTPClient{
|
||||||
|
code: http.StatusOK,
|
||||||
|
resBody: "{}",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &ZitadelManager{
|
||||||
|
managementEndpoint: "https://zitadel.example.com/management/v1",
|
||||||
|
httpClient: mockClient,
|
||||||
|
credentials: &mockCredentials{token: "test-token"},
|
||||||
|
helper: JsonParser{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.DeactivateConnector(context.Background(), "oidc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeState(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"IDP_STATE_ACTIVE", "active"},
|
||||||
|
{"IDP_STATE_INACTIVE", "inactive"},
|
||||||
|
{"custom", "custom"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.input, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.expected, normalizeState(tc.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected ConnectorType
|
||||||
|
}{
|
||||||
|
{"IDP_TYPE_OIDC", ConnectorTypeOIDC},
|
||||||
|
{"IDP_TYPE_OIDC_GENERIC", ConnectorTypeOIDC},
|
||||||
|
{"IDP_TYPE_LDAP", ConnectorTypeLDAP},
|
||||||
|
{"IDP_TYPE_SAML", ConnectorTypeSAML},
|
||||||
|
{"CUSTOM", ConnectorType("CUSTOM")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.input, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.expected, normalizeType(tc.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockCredentials is a mock implementation of ManagerCredentials for testing
|
||||||
|
type mockCredentials struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||||
|
return JWTToken{AccessToken: m.token}, nil
|
||||||
|
}
|
||||||
@@ -1,263 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/oauth2/google"
|
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
|
||||||
"google.golang.org/api/option"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GoogleWorkspaceManager Google Workspace manager client instance.
|
|
||||||
type GoogleWorkspaceManager struct {
|
|
||||||
usersService *admin.UsersService
|
|
||||||
CustomerID string
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
|
|
||||||
type GoogleWorkspaceClientConfig struct {
|
|
||||||
ServiceAccountKey string
|
|
||||||
CustomerID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GoogleWorkspaceCredentials Google Workspace authentication information.
|
|
||||||
type GoogleWorkspaceCredentials struct {
|
|
||||||
clientConfig GoogleWorkspaceClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
|
||||||
return JWTToken{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
|
||||||
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.CustomerID == "" {
|
|
||||||
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
credentials := &GoogleWorkspaceCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new Admin SDK Directory service client
|
|
||||||
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
service, err := admin.NewService(context.Background(),
|
|
||||||
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
|
|
||||||
option.WithCredentials(adminCredentials),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &GoogleWorkspaceManager{
|
|
||||||
usersService: service.Users,
|
|
||||||
CustomerID: config.CustomerID,
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
|
||||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from Google Workspace via ID.
|
|
||||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
user, err := gm.usersService.Get(userID).Do()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if gm.appMetrics != nil {
|
|
||||||
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := parseGoogleWorkspaceUser(user)
|
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
|
||||||
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
users, err := gm.getAllUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if gm.appMetrics != nil {
|
|
||||||
gm.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
for index, user := range users {
|
|
||||||
user.AppMetadata.WTAccountID = accountID
|
|
||||||
users[index] = user
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
|
||||||
users, err := gm.getAllUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
|
||||||
|
|
||||||
if gm.appMetrics != nil {
|
|
||||||
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
|
|
||||||
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
pageToken := ""
|
|
||||||
for {
|
|
||||||
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
|
|
||||||
if pageToken != "" {
|
|
||||||
call.PageToken(pageToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := call.Do()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range resp.Users {
|
|
||||||
users = append(users, parseGoogleWorkspaceUser(user))
|
|
||||||
}
|
|
||||||
|
|
||||||
pageToken = resp.NextPageToken
|
|
||||||
if pageToken == "" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
|
||||||
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
|
||||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
|
||||||
// If no users have been found, this function returns an empty list.
|
|
||||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
|
||||||
user, err := gm.usersService.Get(email).Do()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if gm.appMetrics != nil {
|
|
||||||
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
users = append(users, parseGoogleWorkspaceUser(user))
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
|
|
||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from GoogleWorkspace.
|
|
||||||
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
|
|
||||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if gm.appMetrics != nil {
|
|
||||||
gm.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
|
||||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
|
||||||
// If that fails, it falls back to using the default Google credentials path.
|
|
||||||
// It returns the retrieved credentials or an error if unsuccessful.
|
|
||||||
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
|
|
||||||
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
|
|
||||||
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
creds, err := google.CredentialsFromJSON(
|
|
||||||
context.Background(),
|
|
||||||
decodeKey,
|
|
||||||
admin.AdminDirectoryUserReadonlyScope,
|
|
||||||
)
|
|
||||||
if err == nil {
|
|
||||||
// No need to fallback to the default Google credentials path
|
|
||||||
return creds, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
|
||||||
log.WithContext(ctx).Debug("falling back to default google credentials location")
|
|
||||||
|
|
||||||
creds, err = google.FindDefaultCredentials(
|
|
||||||
context.Background(),
|
|
||||||
admin.AdminDirectoryUserReadonlyScope,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return creds, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseGoogleWorkspaceUser parse google user to UserData.
|
|
||||||
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
|
|
||||||
return &UserData{
|
|
||||||
ID: user.Id,
|
|
||||||
Email: user.PrimaryEmail,
|
|
||||||
Name: user.Name.FullName,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -11,24 +11,25 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// UnsetAccountID is a special key to map users without an account ID
|
|
||||||
UnsetAccountID = "unset"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager idp manager interface
|
// Manager idp manager interface
|
||||||
|
// Note: NetBird is the single source of truth for authorization data (roles, account membership, invite status).
|
||||||
|
// The IdP only stores identity information (email, name, credentials).
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
|
// CreateUser creates a new user in the IdP. Returns basic user data (ID, email, name).
|
||||||
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||||
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
|
// GetUserDataByID retrieves user identity data from the IdP by user ID.
|
||||||
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
|
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||||
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
// GetUserByEmail searches for users by email address.
|
||||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||||
|
// GetAllUsers returns all users from the IdP for cache warming.
|
||||||
|
GetAllUsers(ctx context.Context) ([]*UserData, error)
|
||||||
|
// InviteUserByID resends an invitation to a user who hasn't completed signup.
|
||||||
InviteUserByID(ctx context.Context, userID string) error
|
InviteUserByID(ctx context.Context, userID string) error
|
||||||
|
// DeleteUser removes a user from the IdP.
|
||||||
DeleteUser(ctx context.Context, userID string) error
|
DeleteUser(ctx context.Context, userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConfig defines common client configuration for all IdP manager
|
// ClientConfig defines common client configuration for the IdP manager
|
||||||
type ClientConfig struct {
|
type ClientConfig struct {
|
||||||
Issuer string
|
Issuer string
|
||||||
TokenEndpoint string
|
TokenEndpoint string
|
||||||
@@ -42,13 +43,10 @@ type ExtraConfig map[string]string
|
|||||||
|
|
||||||
// Config an idp configuration struct to be loaded from management server's config file
|
// Config an idp configuration struct to be loaded from management server's config file
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ManagerType string
|
ManagerType string
|
||||||
ClientConfig *ClientConfig
|
ClientConfig *ClientConfig
|
||||||
ExtraConfig ExtraConfig
|
ExtraConfig ExtraConfig
|
||||||
Auth0ClientCredentials *Auth0ClientConfig
|
ZitadelClientCredentials *ZitadelClientConfig
|
||||||
AzureClientCredentials *AzureClientConfig
|
|
||||||
KeycloakClientCredentials *KeycloakClientConfig
|
|
||||||
ZitadelClientCredentials *ZitadelClientConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||||
@@ -67,11 +65,12 @@ type ManagerHelper interface {
|
|||||||
Unmarshal(data []byte, v interface{}) error
|
Unmarshal(data []byte, v interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserData represents identity information from the IdP.
|
||||||
|
// Note: Authorization data (account membership, roles, invite status) is stored in NetBird's DB.
|
||||||
type UserData struct {
|
type UserData struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ID string `json:"user_id"`
|
ID string `json:"user_id"`
|
||||||
AppMetadata AppMetadata `json:"app_metadata"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UserData) MarshalBinary() (data []byte, err error) {
|
func (u *UserData) MarshalBinary() (data []byte, err error) {
|
||||||
@@ -91,15 +90,6 @@ func (u *UserData) Unmarshal(data []byte) (err error) {
|
|||||||
return json.Unmarshal(data, &u)
|
return json.Unmarshal(data, &u)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppMetadata user app metadata to associate with a profile
|
|
||||||
type AppMetadata struct {
|
|
||||||
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
|
|
||||||
// maps to wt_account_id when json.marshal
|
|
||||||
WTAccountID string `json:"wt_account_id,omitempty"`
|
|
||||||
WTPendingInvite *bool `json:"wt_pending_invite,omitempty"`
|
|
||||||
WTInvitedBy string `json:"wt_invited_by_email,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// JWTToken a JWT object that holds information of a token
|
// JWTToken a JWT object that holds information of a token
|
||||||
type JWTToken struct {
|
type JWTToken struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
@@ -109,7 +99,8 @@ type JWTToken struct {
|
|||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager returns a new idp manager based on the configuration that it receives
|
// NewManager returns a new idp manager based on the configuration that it receives.
|
||||||
|
// Only Zitadel is supported as the IdP manager.
|
||||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||||
if config.ClientConfig != nil {
|
if config.ClientConfig != nil {
|
||||||
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
||||||
@@ -118,46 +109,6 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
|||||||
switch strings.ToLower(config.ManagerType) {
|
switch strings.ToLower(config.ManagerType) {
|
||||||
case "none", "":
|
case "none", "":
|
||||||
return nil, nil //nolint:nilnil
|
return nil, nil //nolint:nilnil
|
||||||
case "auth0":
|
|
||||||
auth0ClientConfig := config.Auth0ClientCredentials
|
|
||||||
if config.ClientConfig != nil {
|
|
||||||
auth0ClientConfig = &Auth0ClientConfig{
|
|
||||||
Audience: config.ExtraConfig["Audience"],
|
|
||||||
AuthIssuer: config.ClientConfig.Issuer,
|
|
||||||
ClientID: config.ClientConfig.ClientID,
|
|
||||||
ClientSecret: config.ClientConfig.ClientSecret,
|
|
||||||
GrantType: config.ClientConfig.GrantType,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewAuth0Manager(*auth0ClientConfig, appMetrics)
|
|
||||||
case "azure":
|
|
||||||
azureClientConfig := config.AzureClientCredentials
|
|
||||||
if config.ClientConfig != nil {
|
|
||||||
azureClientConfig = &AzureClientConfig{
|
|
||||||
ClientID: config.ClientConfig.ClientID,
|
|
||||||
ClientSecret: config.ClientConfig.ClientSecret,
|
|
||||||
GrantType: config.ClientConfig.GrantType,
|
|
||||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
|
||||||
ObjectID: config.ExtraConfig["ObjectId"],
|
|
||||||
GraphAPIEndpoint: config.ExtraConfig["GraphApiEndpoint"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewAzureManager(*azureClientConfig, appMetrics)
|
|
||||||
case "keycloak":
|
|
||||||
keycloakClientConfig := config.KeycloakClientCredentials
|
|
||||||
if config.ClientConfig != nil {
|
|
||||||
keycloakClientConfig = &KeycloakClientConfig{
|
|
||||||
ClientID: config.ClientConfig.ClientID,
|
|
||||||
ClientSecret: config.ClientConfig.ClientSecret,
|
|
||||||
GrantType: config.ClientConfig.GrantType,
|
|
||||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
|
||||||
AdminEndpoint: config.ExtraConfig["AdminEndpoint"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewKeycloakManager(*keycloakClientConfig, appMetrics)
|
|
||||||
case "zitadel":
|
case "zitadel":
|
||||||
zitadelClientConfig := config.ZitadelClientCredentials
|
zitadelClientConfig := config.ZitadelClientCredentials
|
||||||
if config.ClientConfig != nil {
|
if config.ClientConfig != nil {
|
||||||
@@ -172,42 +123,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
|||||||
}
|
}
|
||||||
|
|
||||||
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||||
case "authentik":
|
|
||||||
authentikConfig := AuthentikClientConfig{
|
|
||||||
Issuer: config.ClientConfig.Issuer,
|
|
||||||
ClientID: config.ClientConfig.ClientID,
|
|
||||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
|
||||||
GrantType: config.ClientConfig.GrantType,
|
|
||||||
Username: config.ExtraConfig["Username"],
|
|
||||||
Password: config.ExtraConfig["Password"],
|
|
||||||
}
|
|
||||||
return NewAuthentikManager(authentikConfig, appMetrics)
|
|
||||||
case "okta":
|
|
||||||
oktaClientConfig := OktaClientConfig{
|
|
||||||
Issuer: config.ClientConfig.Issuer,
|
|
||||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
|
||||||
GrantType: config.ClientConfig.GrantType,
|
|
||||||
APIToken: config.ExtraConfig["ApiToken"],
|
|
||||||
}
|
|
||||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
|
||||||
case "google":
|
|
||||||
googleClientConfig := GoogleWorkspaceClientConfig{
|
|
||||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
|
||||||
CustomerID: config.ExtraConfig["CustomerId"],
|
|
||||||
}
|
|
||||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
|
||||||
case "jumpcloud":
|
|
||||||
jumpcloudConfig := JumpCloudClientConfig{
|
|
||||||
APIToken: config.ExtraConfig["ApiToken"],
|
|
||||||
}
|
|
||||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
|
||||||
case "pocketid":
|
|
||||||
pocketidConfig := PocketIdClientConfig{
|
|
||||||
APIToken: config.ExtraConfig["ApiToken"],
|
|
||||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
|
||||||
}
|
|
||||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,257 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
contentType = "application/json"
|
|
||||||
accept = "application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
// JumpCloudManager JumpCloud manager client instance.
|
|
||||||
type JumpCloudManager struct {
|
|
||||||
client *v1.APIClient
|
|
||||||
apiToken string
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// JumpCloudClientConfig JumpCloud manager client configurations.
|
|
||||||
type JumpCloudClientConfig struct {
|
|
||||||
APIToken string
|
|
||||||
}
|
|
||||||
|
|
||||||
// JumpCloudCredentials JumpCloud authentication information.
|
|
||||||
type JumpCloudCredentials struct {
|
|
||||||
clientConfig JumpCloudClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewJumpCloudManager creates a new instance of the JumpCloudManager.
|
|
||||||
func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.APIToken == "" {
|
|
||||||
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := v1.NewAPIClient(v1.NewConfiguration())
|
|
||||||
credentials := &JumpCloudCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &JumpCloudManager{
|
|
||||||
client: client,
|
|
||||||
apiToken: config.APIToken,
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate retrieves access token to use the JumpCloud user API.
|
|
||||||
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
|
||||||
return JWTToken{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (jm *JumpCloudManager) authenticationContext() context.Context {
|
|
||||||
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
|
|
||||||
Key: jm.apiToken,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
|
||||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
|
||||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
authCtx := jm.authenticationContext()
|
|
||||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := parseJumpCloudUser(user)
|
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
|
||||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
authCtx := jm.authenticationContext()
|
|
||||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, user := range userList.Results {
|
|
||||||
userData := parseJumpCloudUser(user)
|
|
||||||
userData.AppMetadata.WTAccountID = accountID
|
|
||||||
|
|
||||||
users = append(users, userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
|
||||||
authCtx := jm.authenticationContext()
|
|
||||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
for _, user := range userList.Results {
|
|
||||||
userData := parseJumpCloudUser(user)
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
|
||||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
|
||||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
|
||||||
// If no users have been found, this function returns an empty list.
|
|
||||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
|
||||||
searchFilter := map[string]interface{}{
|
|
||||||
"searchFilter": map[string]interface{}{
|
|
||||||
"filter": []string{email},
|
|
||||||
"fields": []string{"email"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
authCtx := jm.authenticationContext()
|
|
||||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
usersData := make([]*UserData, 0)
|
|
||||||
for _, user := range userList.Results {
|
|
||||||
usersData = append(usersData, parseJumpCloudUser(user))
|
|
||||||
}
|
|
||||||
|
|
||||||
return usersData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
|
||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from jumpCloud directory
|
|
||||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
|
||||||
authCtx := jm.authenticationContext()
|
|
||||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if jm.appMetrics != nil {
|
|
||||||
jm.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
|
|
||||||
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
|
|
||||||
names := []string{user.Firstname, user.Middlename, user.Lastname}
|
|
||||||
return &UserData{
|
|
||||||
Email: user.Email,
|
|
||||||
Name: strings.Join(names, " "),
|
|
||||||
ID: user.Id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewJumpCloudManager(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
name string
|
|
||||||
inputConfig JumpCloudClientConfig
|
|
||||||
assertErrFunc require.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultTestConfig := JumpCloudClientConfig{
|
|
||||||
APIToken: "test123",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase1 := test{
|
|
||||||
name: "Good Configuration",
|
|
||||||
inputConfig: defaultTestConfig,
|
|
||||||
assertErrFunc: require.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase2Config := defaultTestConfig
|
|
||||||
testCase2Config.APIToken = ""
|
|
||||||
|
|
||||||
testCase2 := test{
|
|
||||||
name: "Missing APIToken Configuration",
|
|
||||||
inputConfig: testCase2Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
_, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,439 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KeycloakManager keycloak manager client instance.
|
|
||||||
type KeycloakManager struct {
|
|
||||||
adminEndpoint string
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeycloakClientConfig keycloak manager client configurations.
|
|
||||||
type KeycloakClientConfig struct {
|
|
||||||
ClientID string
|
|
||||||
ClientSecret string
|
|
||||||
AdminEndpoint string
|
|
||||||
TokenEndpoint string
|
|
||||||
GrantType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeycloakCredentials keycloak authentication information.
|
|
||||||
type KeycloakCredentials struct {
|
|
||||||
clientConfig KeycloakClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
jwtToken JWTToken
|
|
||||||
mux sync.Mutex
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// keycloakUserAttributes holds additional user data fields.
|
|
||||||
type keycloakUserAttributes map[string][]string
|
|
||||||
|
|
||||||
// keycloakProfile represents a keycloak user profile response.
|
|
||||||
type keycloakProfile struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
CreatedTimestamp int64 `json:"createdTimestamp"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Attributes keycloakUserAttributes `json:"attributes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
|
||||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ClientSecret == "" {
|
|
||||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, TokenEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.AdminEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GrantType == "" {
|
|
||||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, GrantType is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
credentials := &KeycloakCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &KeycloakManager{
|
|
||||||
adminEndpoint: config.AdminEndpoint,
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
|
|
||||||
func (kc *KeycloakCredentials) jwtStillValid() bool {
|
|
||||||
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// requestJWTToken performs request to get jwt token.
|
|
||||||
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("client_id", kc.clientConfig.ClientID)
|
|
||||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
|
||||||
data.Set("grant_type", kc.clientConfig.GrantType)
|
|
||||||
|
|
||||||
payload := strings.NewReader(data.Encode())
|
|
||||||
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
|
|
||||||
|
|
||||||
resp, err := kc.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if kc.appMetrics != nil {
|
|
||||||
kc.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
|
||||||
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
body, err := io.ReadAll(rawBody)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = kc.helper.Unmarshal(body, &jwtToken)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
|
||||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exp maps into exp from jwt token
|
|
||||||
var IssuedAt struct{ Exp int64 }
|
|
||||||
err = kc.helper.Unmarshal(data, &IssuedAt)
|
|
||||||
if err != nil {
|
|
||||||
return jwtToken, err
|
|
||||||
}
|
|
||||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
|
||||||
|
|
||||||
return jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate retrieves access token to use the keycloak Management API.
|
|
||||||
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
|
||||||
kc.mux.Lock()
|
|
||||||
defer kc.mux.Unlock()
|
|
||||||
|
|
||||||
if kc.appMetrics != nil {
|
|
||||||
kc.appMetrics.IDPMetrics().CountAuthenticate()
|
|
||||||
}
|
|
||||||
|
|
||||||
// reuse the token without requesting a new one if it is not expired,
|
|
||||||
// and if expiry time is sufficient time available to make a request.
|
|
||||||
if kc.jwtStillValid() {
|
|
||||||
return kc.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := kc.requestJWTToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return kc.jwtToken, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return kc.jwtToken, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kc.jwtToken = jwtToken
|
|
||||||
|
|
||||||
return kc.jwtToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
|
||||||
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
|
||||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
|
||||||
// If no users have been found, this function returns an empty list.
|
|
||||||
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("email", email)
|
|
||||||
q.Add("exact", "true")
|
|
||||||
|
|
||||||
body, err := km.get(ctx, "users", q)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := make([]keycloakProfile, 0)
|
|
||||||
err = km.helper.Unmarshal(body, &profiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, profile := range profiles {
|
|
||||||
users = append(users, profile.userData())
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from keycloak via ID.
|
|
||||||
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
|
||||||
body, err := km.get(ctx, "users/"+userID, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
var profile keycloakProfile
|
|
||||||
err = km.helper.Unmarshal(body, &profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return profile.userData(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given account profile.
|
|
||||||
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, profile := range profiles {
|
|
||||||
userData := profile.userData()
|
|
||||||
userData.AppMetadata.WTAccountID = accountID
|
|
||||||
|
|
||||||
users = append(users, userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
for _, profile := range profiles {
|
|
||||||
userData := profile.userData()
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
|
||||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
|
|
||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from Keycloak by user ID.
|
|
||||||
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
|
|
||||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
|
||||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := km.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close() // nolint
|
|
||||||
|
|
||||||
// In the docs, they specified 200, but in the endpoints, they return 204
|
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
|
|
||||||
totalUsers, err := km.totalUsersCount(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("max", fmt.Sprint(*totalUsers))
|
|
||||||
|
|
||||||
body, err := km.get(ctx, "users", q)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := make([]keycloakProfile, 0)
|
|
||||||
err = km.helper.Unmarshal(body, &profiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return profiles, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get perform Get requests.
|
|
||||||
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
|
||||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
|
|
||||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
resp, err := km.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if km.appMetrics != nil {
|
|
||||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// totalUsersCount returns the total count of all user created.
|
|
||||||
// Used when fetching all registered accounts with pagination.
|
|
||||||
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
|
|
||||||
body, err := km.get(ctx, "users/count", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := strconv.Atoi(string(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// userData construct user data from keycloak profile.
|
|
||||||
func (kp keycloakProfile) userData() *UserData {
|
|
||||||
return &UserData{
|
|
||||||
Email: kp.Email,
|
|
||||||
Name: kp.Username,
|
|
||||||
ID: kp.ID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,310 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewKeycloakManager(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
name string
|
|
||||||
inputConfig KeycloakClientConfig
|
|
||||||
assertErrFunc require.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultTestConfig := KeycloakClientConfig{
|
|
||||||
ClientID: "client_id",
|
|
||||||
ClientSecret: "client_secret",
|
|
||||||
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
|
|
||||||
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
|
|
||||||
GrantType: "client_credentials",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase1 := test{
|
|
||||||
name: "Good Configuration",
|
|
||||||
inputConfig: defaultTestConfig,
|
|
||||||
assertErrFunc: require.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase2Config := defaultTestConfig
|
|
||||||
testCase2Config.ClientID = ""
|
|
||||||
|
|
||||||
testCase2 := test{
|
|
||||||
name: "Missing ClientID Configuration",
|
|
||||||
inputConfig: testCase2Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase3Config := defaultTestConfig
|
|
||||||
testCase3Config.ClientSecret = ""
|
|
||||||
|
|
||||||
testCase3 := test{
|
|
||||||
name: "Missing ClientSecret Configuration",
|
|
||||||
inputConfig: testCase3Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase4Config := defaultTestConfig
|
|
||||||
testCase4Config.TokenEndpoint = ""
|
|
||||||
|
|
||||||
testCase4 := test{
|
|
||||||
name: "Missing TokenEndpoint Configuration",
|
|
||||||
inputConfig: testCase3Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCase5Config := defaultTestConfig
|
|
||||||
testCase5Config.GrantType = ""
|
|
||||||
|
|
||||||
testCase5 := test{
|
|
||||||
name: "Missing GrantType Configuration",
|
|
||||||
inputConfig: testCase3Config,
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeycloakRequestJWTToken(t *testing.T) {
|
|
||||||
|
|
||||||
type requestJWTTokenTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputRespBody string
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
|
||||||
name: "Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
|
||||||
name: "Request Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputRespBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputRespBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := KeycloakClientConfig{}
|
|
||||||
|
|
||||||
creds := KeycloakCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := creds.requestJWTToken(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
assert.NoError(t, err, "unable to read the response body")
|
|
||||||
|
|
||||||
jwtToken := JWTToken{}
|
|
||||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
|
||||||
assert.NoError(t, err, "unable to parse the json input")
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
|
|
||||||
type parseRequestJWTResponseTest struct {
|
|
||||||
name string
|
|
||||||
inputRespBody string
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedToken string
|
|
||||||
expectedExpiresIn int
|
|
||||||
assertErrFunc assert.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
exp := 100
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
|
||||||
name: "Parse Good JWT Body",
|
|
||||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: token,
|
|
||||||
expectedExpiresIn: exp,
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "no error was expected",
|
|
||||||
}
|
|
||||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
|
||||||
name: "Parse Bad json JWT Body",
|
|
||||||
inputRespBody: "",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedToken: "",
|
|
||||||
expectedExpiresIn: 0,
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "json error was expected",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
|
||||||
config := KeycloakClientConfig{}
|
|
||||||
|
|
||||||
creds := KeycloakCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeycloakJwtStillValid(t *testing.T) {
|
|
||||||
type jwtStillValidTest struct {
|
|
||||||
name string
|
|
||||||
inputTime time.Time
|
|
||||||
expectedResult bool
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
|
||||||
name: "JWT still valid",
|
|
||||||
inputTime: time.Now().Add(10 * time.Second),
|
|
||||||
expectedResult: true,
|
|
||||||
message: "should be true",
|
|
||||||
}
|
|
||||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
|
||||||
name: "JWT is invalid",
|
|
||||||
inputTime: time.Now(),
|
|
||||||
expectedResult: false,
|
|
||||||
message: "should be false",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
config := KeycloakClientConfig{}
|
|
||||||
|
|
||||||
creds := KeycloakCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeycloakAuthenticate(t *testing.T) {
|
|
||||||
type authenticateTest struct {
|
|
||||||
name string
|
|
||||||
inputCode int
|
|
||||||
inputResBody string
|
|
||||||
inputExpireToken time.Time
|
|
||||||
helper ManagerHelper
|
|
||||||
expectedFuncExitErrDiff error
|
|
||||||
expectedCode int
|
|
||||||
expectedToken string
|
|
||||||
}
|
|
||||||
exp := 5
|
|
||||||
token := newTestJWT(t, exp)
|
|
||||||
|
|
||||||
authenticateTestCase1 := authenticateTest{
|
|
||||||
name: "Get Cached token",
|
|
||||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: nil,
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase2 := authenticateTest{
|
|
||||||
name: "Get Good JWT Response",
|
|
||||||
inputCode: 200,
|
|
||||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: token,
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticateTestCase3 := authenticateTest{
|
|
||||||
name: "Get Bad Status Code",
|
|
||||||
inputCode: 400,
|
|
||||||
inputResBody: "{}",
|
|
||||||
helper: JsonParser{},
|
|
||||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
|
||||||
expectedCode: 200,
|
|
||||||
expectedToken: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
|
|
||||||
jwtReqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputResBody,
|
|
||||||
code: testCase.inputCode,
|
|
||||||
}
|
|
||||||
config := KeycloakClientConfig{}
|
|
||||||
|
|
||||||
creds := KeycloakCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: &jwtReqClient,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
|
||||||
|
|
||||||
_, err := creds.Authenticate(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
if testCase.expectedFuncExitErrDiff != nil {
|
|
||||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
|
||||||
} else {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -4,52 +4,26 @@ import "context"
|
|||||||
|
|
||||||
// MockIDP is a mock implementation of the IDP interface
|
// MockIDP is a mock implementation of the IDP interface
|
||||||
type MockIDP struct {
|
type MockIDP struct {
|
||||||
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
|
CreateUserFunc func(ctx context.Context, email, name string) (*UserData, error)
|
||||||
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
GetUserDataByIDFunc func(ctx context.Context, userId string) (*UserData, error)
|
||||||
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
|
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||||
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
|
GetAllUsersFunc func(ctx context.Context) ([]*UserData, error)
|
||||||
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
|
||||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
|
|
||||||
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
|
|
||||||
if m.UpdateUserAppMetadataFunc != nil {
|
|
||||||
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
|
||||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
if m.GetUserDataByIDFunc != nil {
|
|
||||||
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount is a mock implementation of the IDP interface GetAccount method
|
|
||||||
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
|
||||||
if m.GetAccountFunc != nil {
|
|
||||||
return m.GetAccountFunc(ctx, accountId)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
|
|
||||||
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
if m.GetAllAccountsFunc != nil {
|
|
||||||
return m.GetAllAccountsFunc(ctx)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
||||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
func (m *MockIDP) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||||
if m.CreateUserFunc != nil {
|
if m.CreateUserFunc != nil {
|
||||||
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
|
return m.CreateUserFunc(ctx, email, name)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||||
|
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string) (*UserData, error) {
|
||||||
|
if m.GetUserDataByIDFunc != nil {
|
||||||
|
return m.GetUserDataByIDFunc(ctx, userId)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -62,6 +36,14 @@ func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllUsers is a mock implementation of the IDP interface GetAllUsers method
|
||||||
|
func (m *MockIDP) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||||
|
if m.GetAllUsersFunc != nil {
|
||||||
|
return m.GetAllUsersFunc(ctx)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
||||||
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
||||||
if m.InviteUserByIDFunc != nil {
|
if m.InviteUserByIDFunc != nil {
|
||||||
|
|||||||
@@ -1,306 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
|
||||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OktaManager okta manager client instance.
|
|
||||||
type OktaManager struct {
|
|
||||||
client *okta.Client
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// OktaClientConfig okta manager client configurations.
|
|
||||||
type OktaClientConfig struct {
|
|
||||||
APIToken string
|
|
||||||
Issuer string
|
|
||||||
TokenEndpoint string
|
|
||||||
GrantType string
|
|
||||||
}
|
|
||||||
|
|
||||||
// OktaCredentials okta authentication information.
|
|
||||||
type OktaCredentials struct {
|
|
||||||
clientConfig OktaClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOktaManager creates a new instance of the OktaManager.
|
|
||||||
func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*OktaManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
helper := JsonParser{}
|
|
||||||
config.Issuer = baseURL(config.Issuer)
|
|
||||||
|
|
||||||
if config.APIToken == "" {
|
|
||||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, APIToken is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Issuer == "" {
|
|
||||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, Issuer is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, TokenEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.GrantType == "" {
|
|
||||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, GrantType is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, client, err := okta.NewClient(context.Background(),
|
|
||||||
okta.WithOrgUrl(config.Issuer),
|
|
||||||
okta.WithToken(config.APIToken),
|
|
||||||
okta.WithHttpClientPtr(httpClient),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
credentials := &OktaCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &OktaManager{
|
|
||||||
client: client,
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate retrieves access token to use the okta user API.
|
|
||||||
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
|
||||||
return JWTToken{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
|
||||||
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
|
||||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserDataByID requests user data from keycloak via ID.
|
|
||||||
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
userData, err := parseOktaUser(user)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
|
||||||
// If no users have been found, this function returns an empty list.
|
|
||||||
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
|
||||||
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
userData, err := parseOktaUser(user)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
users = append(users, userData)
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
|
||||||
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
|
||||||
users, err := om.getAllUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
for index, user := range users {
|
|
||||||
user.AppMetadata.WTAccountID = accountID
|
|
||||||
users[index] = user
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
|
||||||
users, err := om.getAllUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
|
||||||
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAllUsers returns all users in an Okta account.
|
|
||||||
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
|
|
||||||
qp := query.NewQueryParams(query.WithLimit(200))
|
|
||||||
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
for resp.HasNextPage() {
|
|
||||||
paginatedUsers := make([]*okta.User, 0)
|
|
||||||
resp, err = resp.Next(context.Background(), &paginatedUsers)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
userList = append(userList, paginatedUsers...)
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0, len(userList))
|
|
||||||
for _, user := range userList {
|
|
||||||
userData, err := parseOktaUser(user)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
users = append(users, userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
|
||||||
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InviteUserByID resend invitations to users who haven't activated,
|
|
||||||
// their accounts prior to the expiration period.
|
|
||||||
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
|
|
||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteUser from Okta
|
|
||||||
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
|
|
||||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
if om.appMetrics != nil {
|
|
||||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseOktaUser parse okta user to UserData.
|
|
||||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
|
||||||
var oktaUser struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
FirstName string `json:"firstName"`
|
|
||||||
LastName string `json:"lastName"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if user == nil {
|
|
||||||
return nil, fmt.Errorf("invalid okta user")
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Profile != nil {
|
|
||||||
helper := JsonParser{}
|
|
||||||
buf, err := helper.Marshal(*user.Profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = helper.Unmarshal(buf, &oktaUser)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UserData{
|
|
||||||
Email: oktaUser.Email,
|
|
||||||
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
|
|
||||||
ID: user.Id,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseOktaUser(t *testing.T) {
|
|
||||||
type parseOktaUserTest struct {
|
|
||||||
name string
|
|
||||||
inputProfile *okta.User
|
|
||||||
expectedUserData *UserData
|
|
||||||
assertErrFunc assert.ErrorAssertionFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
parseOktaTestCase1 := parseOktaUserTest{
|
|
||||||
name: "Good Request",
|
|
||||||
inputProfile: &okta.User{
|
|
||||||
Id: "123",
|
|
||||||
Profile: &okta.UserProfile{
|
|
||||||
"email": "test@example.com",
|
|
||||||
"firstName": "John",
|
|
||||||
"lastName": "Doe",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedUserData: &UserData{
|
|
||||||
Email: "test@example.com",
|
|
||||||
Name: "John Doe",
|
|
||||||
ID: "123",
|
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: "456",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
}
|
|
||||||
|
|
||||||
parseOktaTestCase2 := parseOktaUserTest{
|
|
||||||
name: "Invalid okta user",
|
|
||||||
inputProfile: nil,
|
|
||||||
expectedUserData: nil,
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
userData, err := parseOktaUser(testCase.inputProfile)
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// userDataEqual helper function to compare UserData structs for equality.
|
|
||||||
func userDataEqual(a, b *UserData) bool {
|
|
||||||
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -1,384 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PocketIdManager struct {
|
|
||||||
managementEndpoint string
|
|
||||||
apiToken string
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
credentials ManagerCredentials
|
|
||||||
helper ManagerHelper
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
type pocketIdCustomClaimDto struct {
|
|
||||||
Key string `json:"key"`
|
|
||||||
Value string `json:"value"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type pocketIdUserDto struct {
|
|
||||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
|
||||||
Disabled bool `json:"disabled"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
FirstName string `json:"firstName"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
IsAdmin bool `json:"isAdmin"`
|
|
||||||
LastName string `json:"lastName"`
|
|
||||||
LdapID string `json:"ldapId"`
|
|
||||||
Locale string `json:"locale"`
|
|
||||||
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type pocketIdUserCreateDto struct {
|
|
||||||
Disabled bool `json:"disabled,omitempty"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
FirstName string `json:"firstName"`
|
|
||||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
|
||||||
LastName string `json:"lastName,omitempty"`
|
|
||||||
Locale string `json:"locale,omitempty"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type pocketIdPaginatedUserDto struct {
|
|
||||||
Data []pocketIdUserDto `json:"data"`
|
|
||||||
Pagination pocketIdPaginationDto `json:"pagination"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type pocketIdPaginationDto struct {
|
|
||||||
CurrentPage int `json:"currentPage"`
|
|
||||||
ItemsPerPage int `json:"itemsPerPage"`
|
|
||||||
TotalItems int `json:"totalItems"`
|
|
||||||
TotalPages int `json:"totalPages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pocketIdUserDto) userData() *UserData {
|
|
||||||
return &UserData{
|
|
||||||
Email: p.Email,
|
|
||||||
Name: p.DisplayName,
|
|
||||||
ID: p.ID,
|
|
||||||
AppMetadata: AppMetadata{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type pocketIdUserGroupDto struct {
|
|
||||||
CreatedAt string `json:"createdAt"`
|
|
||||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
|
||||||
FriendlyName string `json:"friendlyName"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
LdapID string `json:"ldapId"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
|
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
httpTransport.MaxIdleConns = 5
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
Transport: httpTransport,
|
|
||||||
}
|
|
||||||
helper := JsonParser{}
|
|
||||||
|
|
||||||
if config.ManagementEndpoint == "" {
|
|
||||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.APIToken == "" {
|
|
||||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
credentials := &PocketIdCredentials{
|
|
||||||
clientConfig: config,
|
|
||||||
httpClient: httpClient,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PocketIdManager{
|
|
||||||
managementEndpoint: config.ManagementEndpoint,
|
|
||||||
apiToken: config.APIToken,
|
|
||||||
httpClient: httpClient,
|
|
||||||
credentials: credentials,
|
|
||||||
helper: helper,
|
|
||||||
appMetrics: appMetrics,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
|
|
||||||
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
|
|
||||||
if !slices.Contains(MethodsWithBody, method) && body != "" {
|
|
||||||
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
|
|
||||||
}
|
|
||||||
|
|
||||||
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
|
|
||||||
if query != nil {
|
|
||||||
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
|
|
||||||
}
|
|
||||||
var req *http.Request
|
|
||||||
var err error
|
|
||||||
if body != "" {
|
|
||||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
|
|
||||||
} else {
|
|
||||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("X-API-KEY", p.apiToken)
|
|
||||||
|
|
||||||
if body != "" {
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := p.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountRequestStatusError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return io.ReadAll(resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAllUsersPaginated fetches all users from PocketID API using pagination
|
|
||||||
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
|
|
||||||
var allUsers []pocketIdUserDto
|
|
||||||
currentPage := 1
|
|
||||||
|
|
||||||
for {
|
|
||||||
params := url.Values{}
|
|
||||||
// Copy existing search parameters
|
|
||||||
for key, values := range searchParams {
|
|
||||||
params[key] = values
|
|
||||||
}
|
|
||||||
|
|
||||||
params.Set("pagination[limit]", "100")
|
|
||||||
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
|
|
||||||
|
|
||||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var profiles pocketIdPaginatedUserDto
|
|
||||||
err = p.helper.Unmarshal(body, &profiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
allUsers = append(allUsers, profiles.Data...)
|
|
||||||
|
|
||||||
// Check if we've reached the last page
|
|
||||||
if currentPage >= profiles.Pagination.TotalPages {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
currentPage++
|
|
||||||
}
|
|
||||||
|
|
||||||
return allUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
|
||||||
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountGetUserDataByID()
|
|
||||||
}
|
|
||||||
|
|
||||||
var user pocketIdUserDto
|
|
||||||
err = p.helper.Unmarshal(body, &user)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := user.userData()
|
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
|
||||||
// Get all users using pagination
|
|
||||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, profile := range allUsers {
|
|
||||||
userData := profile.userData()
|
|
||||||
userData.AppMetadata.WTAccountID = accountId
|
|
||||||
|
|
||||||
users = append(users, userData)
|
|
||||||
}
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
// Get all users using pagination
|
|
||||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountGetAllAccounts()
|
|
||||||
}
|
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
|
||||||
for _, profile := range allUsers {
|
|
||||||
userData := profile.userData()
|
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexedUsers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
|
||||||
firstLast := strings.Split(name, " ")
|
|
||||||
|
|
||||||
createUser := pocketIdUserCreateDto{
|
|
||||||
Disabled: false,
|
|
||||||
DisplayName: name,
|
|
||||||
Email: email,
|
|
||||||
FirstName: firstLast[0],
|
|
||||||
LastName: firstLast[1],
|
|
||||||
Username: firstLast[0] + "." + firstLast[1],
|
|
||||||
}
|
|
||||||
payload, err := p.helper.Marshal(createUser)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var newUser pocketIdUserDto
|
|
||||||
err = p.helper.Unmarshal(body, &newUser)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountCreateUser()
|
|
||||||
}
|
|
||||||
var pending bool = true
|
|
||||||
ret := &UserData{
|
|
||||||
Email: email,
|
|
||||||
Name: name,
|
|
||||||
ID: newUser.ID,
|
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: accountID,
|
|
||||||
WTPendingInvite: &pending,
|
|
||||||
WTInvitedBy: invitedByEmail,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
|
||||||
params := url.Values{
|
|
||||||
// This value a
|
|
||||||
"search": []string{email},
|
|
||||||
}
|
|
||||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountGetUserByEmail()
|
|
||||||
}
|
|
||||||
|
|
||||||
var profiles struct{ data []pocketIdUserDto }
|
|
||||||
err = p.helper.Unmarshal(body, &profiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, profile := range profiles.data {
|
|
||||||
users = append(users, profile.userData())
|
|
||||||
}
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
|
|
||||||
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
|
|
||||||
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.appMetrics != nil {
|
|
||||||
p.appMetrics.IDPMetrics().CountDeleteUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Manager = (*PocketIdManager)(nil)
|
|
||||||
|
|
||||||
type PocketIdClientConfig struct {
|
|
||||||
APIToken string
|
|
||||||
ManagementEndpoint string
|
|
||||||
}
|
|
||||||
|
|
||||||
type PocketIdCredentials struct {
|
|
||||||
clientConfig PocketIdClientConfig
|
|
||||||
helper ManagerHelper
|
|
||||||
httpClient ManagerHTTPClient
|
|
||||||
appMetrics telemetry.AppMetrics
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
|
|
||||||
|
|
||||||
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
|
||||||
return JWTToken{}, nil
|
|
||||||
}
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
package idp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewPocketIdManager(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
name string
|
|
||||||
inputConfig PocketIdClientConfig
|
|
||||||
assertErrFunc require.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultTestConfig := PocketIdClientConfig{
|
|
||||||
APIToken: "api_token",
|
|
||||||
ManagementEndpoint: "http://localhost",
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []test{
|
|
||||||
{
|
|
||||||
name: "Good Configuration",
|
|
||||||
inputConfig: defaultTestConfig,
|
|
||||||
assertErrFunc: require.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Missing ManagementEndpoint",
|
|
||||||
inputConfig: PocketIdClientConfig{
|
|
||||||
APIToken: defaultTestConfig.APIToken,
|
|
||||||
ManagementEndpoint: "",
|
|
||||||
},
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Missing APIToken",
|
|
||||||
inputConfig: PocketIdClientConfig{
|
|
||||||
APIToken: "",
|
|
||||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
|
||||||
},
|
|
||||||
assertErrFunc: require.Error,
|
|
||||||
assertErrFuncMessage: "should return error when field empty",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
|
||||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
mgr.httpClient = client
|
|
||||||
|
|
||||||
md := AppMetadata{WTAccountID: "acc1"}
|
|
||||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "u1", got.ID)
|
|
||||||
assert.Equal(t, "user1@example.com", got.Email)
|
|
||||||
assert.Equal(t, "User One", got.Name)
|
|
||||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
|
||||||
// Single page response with two users
|
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
mgr.httpClient = client
|
|
||||||
|
|
||||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, users, 2)
|
|
||||||
assert.Equal(t, "u1", users[0].ID)
|
|
||||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
|
||||||
assert.Equal(t, "u2", users[1].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
mgr.httpClient = client
|
|
||||||
|
|
||||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, accounts[UnsetAccountID], 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPocketID_CreateUser(t *testing.T) {
|
|
||||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
mgr.httpClient = client
|
|
||||||
|
|
||||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "newid", ud.ID)
|
|
||||||
assert.Equal(t, "new@example.com", ud.Email)
|
|
||||||
assert.Equal(t, "New User", ud.Name)
|
|
||||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
|
||||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
|
||||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
|
||||||
}
|
|
||||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
|
||||||
// Same mock for both calls; returns OK with empty JSON
|
|
||||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
|
||||||
|
|
||||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
mgr.httpClient = client
|
|
||||||
|
|
||||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = mgr.DeleteUser(context.Background(), "u1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
47
management/server/idp/test_util_test.go
Normal file
47
management/server/idp/test_util_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package idp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockHTTPClient is a mock implementation of ManagerHTTPClient for testing
|
||||||
|
type mockHTTPClient struct {
|
||||||
|
code int
|
||||||
|
resBody string
|
||||||
|
reqBody string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
if c.err != nil {
|
||||||
|
return nil, c.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Body != nil {
|
||||||
|
body, _ := io.ReadAll(req.Body)
|
||||||
|
c.reqBody = string(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: c.code,
|
||||||
|
Body: io.NopCloser(bytes.NewReader([]byte(c.resBody))),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestJWT creates a test JWT token with the given expiration time in seconds
|
||||||
|
func newTestJWT(t *testing.T, expiresIn int) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||||
|
exp := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
|
||||||
|
payload := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))
|
||||||
|
signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature"))
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||||
|
}
|
||||||
@@ -319,14 +319,19 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
||||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
// Note: Authorization data (account membership, invite status) is stored in NetBird's DB, not in the IdP.
|
||||||
|
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||||
firstLast := strings.SplitN(name, " ", 2)
|
firstLast := strings.SplitN(name, " ", 2)
|
||||||
|
lastName := firstLast[0]
|
||||||
|
if len(firstLast) > 1 {
|
||||||
|
lastName = firstLast[1]
|
||||||
|
}
|
||||||
|
|
||||||
var addUser = map[string]any{
|
var addUser = map[string]any{
|
||||||
"userName": email,
|
"userName": email,
|
||||||
"profile": map[string]string{
|
"profile": map[string]string{
|
||||||
"firstName": firstLast[0],
|
"firstName": firstLast[0],
|
||||||
"lastName": firstLast[0],
|
"lastName": lastName,
|
||||||
"displayName": name,
|
"displayName": name,
|
||||||
},
|
},
|
||||||
"email": map[string]any{
|
"email": map[string]any{
|
||||||
@@ -357,18 +362,11 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var pending bool = true
|
return &UserData{
|
||||||
ret := &UserData{
|
|
||||||
Email: email,
|
Email: email,
|
||||||
Name: name,
|
Name: name,
|
||||||
ID: newUser.UserId,
|
ID: newUser.UserId,
|
||||||
AppMetadata: AppMetadata{
|
}, nil
|
||||||
WTAccountID: accountID,
|
|
||||||
WTPendingInvite: &pending,
|
|
||||||
WTInvitedBy: invitedByEmail,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return ret, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
// GetUserByEmail searches users with a given email.
|
||||||
@@ -413,7 +411,7 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserDataByID requests user data from zitadel via ID.
|
// GetUserDataByID requests user data from zitadel via ID.
|
||||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string) (*UserData, error) {
|
||||||
body, err := zm.get(ctx, "users/"+userID, nil)
|
body, err := zm.get(ctx, "users/"+userID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -429,43 +427,12 @@ func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, ap
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := profile.User.userData()
|
return profile.User.userData(), nil
|
||||||
userData.AppMetadata = appMetadata
|
|
||||||
|
|
||||||
return userData, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
// GetAllUsers returns all users from the IdP.
|
||||||
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
// Used for cache warming - NetBird matches these against its own user database.
|
||||||
body, err := zm.post(ctx, "users/_search", "")
|
func (zm *ZitadelManager) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if zm.appMetrics != nil {
|
|
||||||
zm.appMetrics.IDPMetrics().CountGetAccount()
|
|
||||||
}
|
|
||||||
|
|
||||||
var profiles struct{ Result []zitadelProfile }
|
|
||||||
err = zm.helper.Unmarshal(body, &profiles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
|
||||||
for _, profile := range profiles.Result {
|
|
||||||
userData := profile.userData()
|
|
||||||
userData.AppMetadata.WTAccountID = accountID
|
|
||||||
|
|
||||||
users = append(users, userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
|
||||||
// It returns a list of users indexed by accountID.
|
|
||||||
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
|
||||||
body, err := zm.post(ctx, "users/_search", "")
|
body, err := zm.post(ctx, "users/_search", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -481,19 +448,12 @@ func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*Use
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
users := make([]*UserData, 0, len(profiles.Result))
|
||||||
for _, profile := range profiles.Result {
|
for _, profile := range profiles.Result {
|
||||||
userData := profile.userData()
|
users = append(users, profile.userData())
|
||||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return indexedUsers, nil
|
return users, nil
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
|
||||||
// Metadata values are base64 encoded.
|
|
||||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type inviteUserRequest struct {
|
type inviteUserRequest struct {
|
||||||
|
|||||||
@@ -288,16 +288,14 @@ func TestZitadelAuthenticate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestZitadelProfile(t *testing.T) {
|
func TestZitadelProfile(t *testing.T) {
|
||||||
type azureProfileTest struct {
|
type zitadelProfileTest struct {
|
||||||
name string
|
name string
|
||||||
invite bool
|
|
||||||
inputProfile zitadelProfile
|
inputProfile zitadelProfile
|
||||||
expectedUserData UserData
|
expectedUserData UserData
|
||||||
}
|
}
|
||||||
|
|
||||||
azureProfileTestCase1 := azureProfileTest{
|
zitadelProfileTestCase1 := zitadelProfileTest{
|
||||||
name: "User Request",
|
name: "User Request",
|
||||||
invite: false,
|
|
||||||
inputProfile: zitadelProfile{
|
inputProfile: zitadelProfile{
|
||||||
ID: "test1",
|
ID: "test1",
|
||||||
State: "USER_STATE_ACTIVE",
|
State: "USER_STATE_ACTIVE",
|
||||||
@@ -322,15 +320,11 @@ func TestZitadelProfile(t *testing.T) {
|
|||||||
ID: "test1",
|
ID: "test1",
|
||||||
Name: "ZITADEL Admin",
|
Name: "ZITADEL Admin",
|
||||||
Email: "test1@mail.com",
|
Email: "test1@mail.com",
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: "1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
azureProfileTestCase2 := azureProfileTest{
|
zitadelProfileTestCase2 := zitadelProfileTest{
|
||||||
name: "Service User Request",
|
name: "Service User Request",
|
||||||
invite: true,
|
|
||||||
inputProfile: zitadelProfile{
|
inputProfile: zitadelProfile{
|
||||||
ID: "test2",
|
ID: "test2",
|
||||||
State: "USER_STATE_ACTIVE",
|
State: "USER_STATE_ACTIVE",
|
||||||
@@ -345,15 +339,11 @@ func TestZitadelProfile(t *testing.T) {
|
|||||||
ID: "test2",
|
ID: "test2",
|
||||||
Name: "machine",
|
Name: "machine",
|
||||||
Email: "machine",
|
Email: "machine",
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: "1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
for _, testCase := range []zitadelProfileTest{zitadelProfileTestCase1, zitadelProfileTestCase2} {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
|
||||||
userData := testCase.inputProfile.userData()
|
userData := testCase.inputProfile.userData()
|
||||||
|
|
||||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -487,7 +486,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
|||||||
|
|
||||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||||
if am.idpManager != nil {
|
if am.idpManager != nil {
|
||||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
userdata, err := am.idpManager.GetUserDataByID(ctx, userID)
|
||||||
if err == nil && userdata != nil {
|
if err == nil && userdata != nil {
|
||||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||||
}
|
}
|
||||||
|
|||||||
9974
management/server/types/testdata/networkmap_golden.json
vendored
Normal file
9974
management/server/types/testdata/networkmap_golden.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9974
management/server/types/testdata/networkmap_golden_new.json
vendored
Normal file
9974
management/server/types/testdata/networkmap_golden_new.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_new_with_deleted_router.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_new_with_deleted_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded_router.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_new_with_onpeeradded_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_new_with_onpeerdeleted.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_new_with_onpeerdeleted.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_with_deleted_peer.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_with_deleted_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
9862
management/server/types/testdata/networkmap_golden_with_deleted_router_peer.json
vendored
Normal file
9862
management/server/types/testdata/networkmap_golden_with_deleted_router_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_with_new_peer.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_with_new_peer.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
10086
management/server/types/testdata/networkmap_golden_with_new_router.json
vendored
Normal file
10086
management/server/types/testdata/networkmap_golden_with_new_router.json
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,8 @@ type User struct {
|
|||||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||||
Blocked bool
|
Blocked bool
|
||||||
|
// PendingInvite indicates whether the user has accepted their invite and logged in
|
||||||
|
PendingInvite bool
|
||||||
// PendingApproval indicates whether the user requires approval before being activated
|
// PendingApproval indicates whether the user requires approval before being activated
|
||||||
PendingApproval bool
|
PendingApproval bool
|
||||||
// LastLogin is the last time the user logged in to IdP
|
// LastLogin is the last time the user logged in to IdP
|
||||||
@@ -162,7 +164,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userStatus := UserStatusActive
|
userStatus := UserStatusActive
|
||||||
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
if u.PendingInvite {
|
||||||
userStatus = UserStatusInvited
|
userStatus = UserStatusInvited
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,6 +201,7 @@ func (u *User) Copy() *User {
|
|||||||
ServiceUserName: u.ServiceUserName,
|
ServiceUserName: u.ServiceUserName,
|
||||||
PATs: pats,
|
PATs: pats,
|
||||||
Blocked: u.Blocked,
|
Blocked: u.Blocked,
|
||||||
|
PendingInvite: u.PendingInvite,
|
||||||
PendingApproval: u.PendingApproval,
|
PendingApproval: u.PendingApproval,
|
||||||
LastLogin: u.LastLogin,
|
LastLogin: u.LastLogin,
|
||||||
CreatedAt: u.CreatedAt,
|
CreatedAt: u.CreatedAt,
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
Issued: invite.Issued,
|
Issued: invite.Issued,
|
||||||
IntegrationReference: invite.IntegrationReference,
|
IntegrationReference: invite.IntegrationReference,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
|
PendingInvite: true, // User hasn't accepted invite yet
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||||
@@ -169,7 +170,7 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
|||||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email)
|
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||||
@@ -285,8 +286,8 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if the user is already registered with this ID
|
// Get user from NetBird's database (source of truth for authorization data)
|
||||||
user, err := am.lookupUserInCache(ctx, targetUserID, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -295,18 +296,17 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
|||||||
return status.Errorf(status.NotFound, "user account %s doesn't exist", targetUserID)
|
return status.Errorf(status.NotFound, "user account %s doesn't exist", targetUserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if user account is already invited and account is not activated
|
// Check if user is still pending invite (hasn't activated their account)
|
||||||
pendingInvite := user.AppMetadata.WTPendingInvite
|
if !user.PendingInvite {
|
||||||
if pendingInvite == nil || !*pendingInvite {
|
|
||||||
return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account")
|
return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.idpManager.InviteUserByID(ctx, user.ID)
|
err = am.idpManager.InviteUserByID(ctx, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, user.ID, accountID, activity.UserInvited, nil)
|
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserInvited, nil)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1011,17 +1011,15 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
|||||||
|
|
||||||
func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error {
|
func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error {
|
||||||
if am.userDeleteFromIDPEnabled {
|
if am.userDeleteFromIDPEnabled {
|
||||||
log.WithContext(ctx).Debugf("user %s deleted from IdP", targetUserID)
|
log.WithContext(ctx).Debugf("deleting user %s from IdP", targetUserID)
|
||||||
err := am.idpManager.DeleteUser(ctx, targetUserID)
|
err := am.idpManager.DeleteUser(ctx, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
|
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
err := am.idpManager.UpdateUserAppMetadata(ctx, targetUserID, idp.AppMetadata{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// Note: If userDeleteFromIDPEnabled is false, the user remains in IdP but is removed from NetBird.
|
||||||
|
// This allows the user to re-authenticate if they're re-invited later.
|
||||||
|
|
||||||
err := am.removeUserFromCache(ctx, accountID, targetUserID)
|
err := am.removeUserFromCache(ctx, accountID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err)
|
log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err)
|
||||||
@@ -1095,7 +1093,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
|||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
// Delete if the user already exists in the IdP. Necessary in cases where a user account
|
// Delete if the user already exists in the IdP. Necessary in cases where a user account
|
||||||
// was created where a user account was provisioned but the user did not sign in
|
// was created where a user account was provisioned but the user did not sign in
|
||||||
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID, idp.AppMetadata{WTAccountID: accountID})
|
_, err := am.idpManager.GetUserDataByID(ctx, targetUserInfo.ID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = am.deleteUserFromIDP(ctx, targetUserInfo.ID, accountID)
|
err = am.deleteUserFromIDP(ctx, targetUserInfo.ID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -563,7 +563,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
idpMock := idp.MockIDP{
|
idpMock := idp.MockIDP{
|
||||||
CreateUserFunc: func(_ context.Context, email, name, accountID, invitedByEmail string) (*idp.UserData, error) {
|
CreateUserFunc: func(_ context.Context, email, name string) (*idp.UserData, error) {
|
||||||
newData := &idp.UserData{
|
newData := &idp.UserData{
|
||||||
Email: email,
|
Email: email,
|
||||||
Name: name,
|
Name: name,
|
||||||
@@ -574,7 +574,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
|||||||
|
|
||||||
return newData, nil
|
return newData, nil
|
||||||
},
|
},
|
||||||
GetAccountFunc: func(_ context.Context, accountId string) ([]*idp.UserData, error) {
|
GetAllUsersFunc: func(_ context.Context) ([]*idp.UserData, error) {
|
||||||
return mockData, nil
|
return mockData, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -1068,7 +1068,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
|||||||
am := DefaultAccountManager{
|
am := DefaultAccountManager{
|
||||||
Store: store,
|
Store: store,
|
||||||
eventStore: &activity.InMemoryEventStore{},
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
|
idpManager: &idp.MockIDP{}, // empty manager
|
||||||
cacheLoading: map[string]chan struct{}{},
|
cacheLoading: map[string]chan struct{}{},
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user