Compare commits

...

6 Commits

Author SHA1 Message Date
jnfrati
5ec45a4923 update action permissions 2026-04-28 17:54:16 +02:00
jnfrati
d5bf6451f4 feat: add comment automation on release workflow for PRs 2026-04-28 17:19:46 +02:00
Zoltan Papp
8fc4265995 [relay] evict foreign client cache on disconnect (#6015)
* [relay] evict foreign client cache on disconnect

When a foreign relay's TCP connection drops, the manager's
onServerDisconnected handler only triggered reconnect logic for the
home server; the disconnected foreign entry stayed in the relayClients
cache. Subsequent OpenConn calls reused the closed client until the
60-second cleanup tick evicted it, breaking peer connectivity through
that relay for up to a minute.

Evict the foreign entry from the cache on disconnect so the next
OpenConn dials a fresh client.

Also:
- Make the reconnect backoff cap configurable via WithMaxBackoffInterval
  ManagerOption; the previous hard-coded 60s constant forced
  TestAutoReconnect to sleep ~61s. Test now polls Ready() and finishes
  in ~2s.
- Add NB_HOME_RELAY_SERVERS env var that overrides the relay URL list
  received from management, so a peer can be pinned to a specific home
  relay (used by the netbird-conn-lab Edge 4 reproducer).

* [client] treat empty NB_HOME_RELAY_SERVERS as unset

Returning (urls=[], ok=true) when the env var contained only separators or
whitespace caused callers to wipe the mgmt-provided relay list, leaving the
peer with no relays. Treat a parsed-empty result the same as an unset env.
2026-04-28 15:04:48 +02:00
Zoltan Papp
9c50819f20 Don't mark management disconnected on transient job stream errors (#6005)
The JOB stream is a separate channel from the SYNC stream. Server-side
EOF or other transient errors on the JOB stream do not indicate that
the management connection is unhealthy — the SYNC stream remains the
authoritative state signal.

Previously, a JOB stream EOF would call notifyDisconnected and the
client would emit OnConnecting to the UI. The backoff retry would
reconnect the JOB stream, but handleJobStream never calls notifyConnected
on success, so the UI was stuck on "Connecting" until the next SYNC
event or health check.

Keep notifyDisconnected for codes.PermissionDenied since IsLoginRequired
relies on managementError to detect expired auth.
2026-04-28 15:04:41 +02:00
Bethuel Mmbaga
6f0eff3ba0 [management] Handle single-string JWT group claim from IdPs (#6014) 2026-04-28 14:48:28 +03:00
Bethuel Mmbaga
f8745723fc [management] Add Microsoft AD FS support for embedded Dex identity providers (#6008) 2026-04-28 12:42:19 +03:00
16 changed files with 299 additions and 39 deletions

View File

@@ -115,6 +115,12 @@ jobs:
release: release:
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
outputs:
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env: env:
flags: "" flags: ""
steps: steps:
@@ -213,10 +219,13 @@ jobs:
if: always() if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only) - name: Tag and push images (amd64 only)
id: tag_and_push_images
if: | if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main') (github.event_name == 'push' && github.ref == 'refs/heads/main')
run: | run: |
set -euo pipefail
resolve_tags() { resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}" echo "pr-${{ github.event.pull_request.number }}"
@@ -225,6 +234,17 @@ jobs:
fi fi
} }
ghcr_package_url() {
local image="$1" package encoded_package
package="${image#ghcr.io/}"
package="${package#*/}"
package="${package%%:*}"
encoded_package="${package//\//%2F}"
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
}
image_refs=()
tag_and_push() { tag_and_push() {
local src="$1" img_name tag dst local src="$1" img_name tag dst
img_name="${src%%:*}" img_name="${src%%:*}"
@@ -233,35 +253,56 @@ jobs:
echo "Tagging ${src} -> ${dst}" echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst" docker tag "$src" "$dst"
docker push "$dst" docker push "$dst"
image_refs+=("$dst")
done done
} }
export -f tag_and_push resolve_tags cat > /tmp/goreleaser-artifacts.json <<'JSON'
${{ steps.goreleaser.outputs.artifacts }}
JSON
echo '${{ steps.goreleaser.outputs.artifacts }}' | \ mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
grep '^ghcr.io/' | while read -r SRC; do )
tag_and_push "$SRC"
done for src in "${src_images[@]}"; do
tag_and_push "$src"
done
{
echo "images_markdown<<EOF"
if [[ ${#image_refs[@]} -eq 0 ]]; then
echo "_No GHCR images were pushed._"
else
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
done
fi
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 7 retention-days: 7
- name: upload linux packages - name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 7 retention-days: 7
- name: upload windows packages - name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 7 retention-days: 7
- name: upload macos packages - name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
@@ -270,6 +311,8 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps: steps:
- name: Parse semver string - name: Parse semver string
id: semver_parser id: semver_parser
@@ -360,6 +403,7 @@ jobs:
if: always() if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui name: release-ui
@@ -368,6 +412,8 @@ jobs:
release_ui_darwin: release_ui_darwin:
runs-on: macos-latest runs-on: macos-latest
outputs:
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -402,12 +448,110 @@ jobs:
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
retention-days: 3 retention-days: 3
comment_release_artifacts:
name: Comment release artifacts
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const marker = '<!-- netbird-release-artifacts -->';
const { owner, repo } = context.repo;
const issue_number = context.payload.pull_request.number;
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
const artifactCell = (url, result) => {
if (url) return `[Download](${url})`;
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
};
const artifacts = [
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
];
const artifactRows = artifacts
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
.join('\n');
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
const body = [
marker,
'## Release artifacts',
'',
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
'',
'| Artifact | Link |',
'| --- | --- |',
artifactRows,
'',
'### GHCR images (amd64)',
ghcrImages,
'',
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
].join('\n');
const comments = await github.paginate(github.rest.issues.listComments, {
owner,
repo,
issue_number,
per_page: 100,
});
const previous = comments.find(comment =>
comment.user?.type === 'Bot' && comment.body?.includes(marker)
);
if (previous) {
await github.rest.issues.updateComment({
owner,
repo,
comment_id: previous.id,
body,
});
core.info(`Updated release artifacts comment ${previous.id}`);
} else {
const { data } = await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
core.info(`Created release artifacts comment ${data.id}`);
}
trigger_signer: trigger_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin] needs: [release, release_ui, release_ui_darwin]

View File

@@ -333,6 +333,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)

View File

@@ -944,7 +944,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err) return fmt.Errorf("update relay token: %w", err)
} }
e.relayManager.UpdateServerURLs(update.Urls) urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled. // Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries. // We can ignore all errors because the guard will manage the reconnection retries.

View File

@@ -7,7 +7,8 @@ import (
) )
const ( const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY" EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
) )
func IsForceRelayed() bool { func IsForceRelayed() bool {
@@ -16,3 +17,28 @@ func IsForceRelayed() bool {
} }
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
} }
// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
raw := os.Getenv(EnvKeyNBHomeRelayServers)
if raw == "" {
return nil, false
}
parts := strings.Split(raw, ",")
urls := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
urls = append(urls, p)
}
}
if len(urls) == 0 {
return nil, false
}
return urls, true
}

View File

@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
// are stored with types that Dex can open. // are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) { func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType { switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
return "oidc", applyOIDCDefaults(connType, config) return "oidc", applyOIDCDefaults(connType, config)
default: default:
return connType, config return connType, config
@@ -218,6 +218,8 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"}) setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid": case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"} augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
} }
return augmented return augmented

View File

@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
var err error var err error
switch cfg.Type { switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
dexType = "oidc" dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI) configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google": case "google":
@@ -220,6 +220,8 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid": case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
} }
return encodeConnectorConfig(oidcConfig) return encodeConnectorConfig(oidcConfig)
} }
@@ -283,7 +285,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
// inferOIDCProviderType infers the specific OIDC provider from connector ID // inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string { func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID) connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
if strings.Contains(connectorIDLower, provider) { if strings.Contains(connectorIDLower, provider) {
return provider return provider
} }

View File

@@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C
} }
// generateIdentityProviderID generates a unique ID for an identity provider. // generateIdentityProviderID generates a unique ID for an identity provider.
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft), // For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs),
// the ID is prefixed with the type name. Generic OIDC providers get no prefix. // the ID is prefixed with the type name. Generic OIDC providers get no prefix.
func generateIdentityProviderID(idpType types.IdentityProviderType) string { func generateIdentityProviderID(idpType types.IdentityProviderType) string {
id := xid.New().String() id := xid.New().String()
@@ -296,6 +296,8 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string {
return "authentik-" + id return "authentik-" + id
case types.IdentityProviderTypeKeycloak: case types.IdentityProviderTypeKeycloak:
return "keycloak-" + id return "keycloak-" + id
case types.IdentityProviderTypeADFS:
return "adfs-" + id
default: default:
// Generic OIDC - no prefix // Generic OIDC - no prefix
return id return id

View File

@@ -39,6 +39,8 @@ const (
IdentityProviderTypeAuthentik IdentityProviderType = "authentik" IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
// IdentityProviderTypeKeycloak is the Keycloak identity provider // IdentityProviderTypeKeycloak is the Keycloak identity provider
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak" IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
// IdentityProviderTypeADFS is the Microsoft AD FS identity provider
IdentityProviderTypeADFS IdentityProviderType = "adfs"
) )
// IdentityProvider represents an identity provider configuration // IdentityProvider represents an identity provider configuration
@@ -112,7 +114,8 @@ func (t IdentityProviderType) IsValid() bool {
switch t { switch t {
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra, case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID, IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak: IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak,
IdentityProviderTypeADFS:
return true return true
} }
return false return false

View File

@@ -146,7 +146,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string
userJWTGroups := make([]string, 0) userJWTGroups := make([]string, 0)
if claim, ok := claims[claimName]; ok { if claim, ok := claims[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok { switch claimGroups := claim.(type) {
case string:
// Some IdPs emit a single group claim as a string instead of an array.
userJWTGroups = append(userJWTGroups, claimGroups)
case []any:
for _, g := range claimGroups { for _, g := range claimGroups {
if group, ok := g.(string); ok { if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group) userJWTGroups = append(userJWTGroups, group)
@@ -154,9 +158,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string
log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
} }
} }
default:
log.Debugf("JWT claim %q is not a string or string array (type: %T): %v", claimName, claim, claim)
} }
} else { } else {
log.Debugf("JWT claim %q is not a string array", claimName) log.Debugf("JWT claim %q is missing", claimName)
} }
return userJWTGroups return userJWTGroups

View File

@@ -249,6 +249,15 @@ func TestClaimsExtractor_ToGroups(t *testing.T) {
groupClaimName: "groups", groupClaimName: "groups",
expectedGroups: []string{}, expectedGroups: []string{},
}, },
{
name: "extracts single group string from claim",
claims: jwt.MapClaims{
"sub": "user-123",
"groups": "admin",
},
groupClaimName: "groups",
expectedGroups: []string{"admin"},
},
{ {
name: "handles custom claim name", name: "handles custom claim name",
claims: jwt.MapClaims{ claims: jwt.MapClaims{

View File

@@ -252,21 +252,19 @@ func (c *GrpcClient) handleJobStream(
c.notifyDisconnected(err) c.notifyDisconnected(err)
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled: case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown") log.Debugf("job stream context has been canceled, this usually indicates shutdown")
return err return err
case codes.Unimplemented: case codes.Unimplemented:
log.Warn("Job feature is not supported by the current management server version. " + log.Warn("Job feature is not supported by the current management server version. " +
"Please update the management service to use this feature.") "Please update the management service to use this feature.")
return nil return nil
default: default:
c.notifyDisconnected(err) log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err return err
} }
} else { } else {
// non-gRPC error // non-gRPC error
c.notifyDisconnected(err) log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err return err
} }
} }

View File

@@ -2917,6 +2917,7 @@ components:
- okta - okta
- pocketid - pocketid
- microsoft - microsoft
- adfs
example: oidc example: oidc
IdentityProvider: IdentityProvider:
type: object type: object

View File

@@ -518,6 +518,7 @@ const (
IdentityProviderTypeOkta IdentityProviderType = "okta" IdentityProviderTypeOkta IdentityProviderType = "okta"
IdentityProviderTypePocketid IdentityProviderType = "pocketid" IdentityProviderTypePocketid IdentityProviderType = "pocketid"
IdentityProviderTypeZitadel IdentityProviderType = "zitadel" IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
IdentityProviderTypeAdfs IdentityProviderType = "adfs"
) )
// Valid indicates whether the value is a known member of the IdentityProviderType enum. // Valid indicates whether the value is a known member of the IdentityProviderType enum.
@@ -537,6 +538,8 @@ func (e IdentityProviderType) Valid() bool {
return true return true
case IdentityProviderTypeZitadel: case IdentityProviderTypeZitadel:
return true return true
case IdentityProviderTypeAdfs:
return true
default: default:
return false return false
} }

View File

@@ -8,10 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const defaultMaxBackoffInterval = 60 * time.Second
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second
)
// Guard manage the reconnection tries to the Relay server in case of disconnection event. // Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct { type Guard struct {
@@ -19,14 +16,23 @@ type Guard struct {
OnNewRelayClient chan *Client OnNewRelayClient chan *Client
OnReconnected chan struct{} OnReconnected chan struct{}
serverPicker *ServerPicker serverPicker *ServerPicker
// maxBackoffInterval caps the exponential backoff between reconnect
// attempts.
maxBackoffInterval time.Duration
} }
// NewGuard creates a new guard for the relay client. // NewGuard creates a new guard for the relay client. A non-positive
func NewGuard(sp *ServerPicker) *Guard { // maxBackoffInterval falls back to defaultMaxBackoffInterval.
func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard {
if maxBackoffInterval <= 0 {
maxBackoffInterval = defaultMaxBackoffInterval
}
g := &Guard{ g := &Guard{
OnNewRelayClient: make(chan *Client, 1), OnNewRelayClient: make(chan *Client, 1),
OnReconnected: make(chan struct{}, 1), OnReconnected: make(chan struct{}, 1),
serverPicker: sp, serverPicker: sp,
maxBackoffInterval: maxBackoffInterval,
} }
return g return g
} }
@@ -49,7 +55,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
} }
// start a ticker to pick a new server // start a ticker to pick a new server
ticker := exponentTicker(ctx) ticker := g.exponentTicker(ctx)
defer ticker.Stop() defer ticker.Stop()
for { for {
@@ -125,11 +131,11 @@ func (g *Guard) notifyReconnected() {
} }
} }
func exponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second, InitialInterval: 2 * time.Second,
Multiplier: 2, Multiplier: 2,
MaxInterval: reconnectingTimeout, MaxInterval: g.maxBackoffInterval,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)

View File

@@ -39,6 +39,15 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func() type OnServerCloseListener func()
// ManagerOption configures a Manager at construction time.
type ManagerOption func(*Manager)
// WithMaxBackoffInterval caps the exponential backoff between reconnect
// attempts to the home relay. A non-positive value keeps the default.
func WithMaxBackoffInterval(d time.Duration) ManagerOption {
return func(m *Manager) { m.maxBackoffInterval = d }
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection. // and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a // The manager also manage temporary relay connection. If a client wants to communicate with a client on a
@@ -64,12 +73,13 @@ type Manager struct {
onReconnectedListenerFn func() onReconnectedListenerFn func()
listenerLock sync.Mutex listenerLock sync.Mutex
mtu uint16 mtu uint16
maxBackoffInterval time.Duration
} }
// NewManager creates a new manager instance. // NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve. // The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
tokenStore := &relayAuth.TokenStore{} tokenStore := &relayAuth.TokenStore{}
m := &Manager{ m := &Manager{
@@ -86,8 +96,11 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List), onDisconnectedListeners: make(map[string]*list.List),
} }
for _, opt := range opts {
opt(m)
}
m.serverPicker.ServerURLs.Store(serverURLs) m.serverPicker.ServerURLs.Store(serverURLs)
m.reconnectGuard = NewGuard(m.serverPicker) m.reconnectGuard = NewGuard(m.serverPicker, m.maxBackoffInterval)
return m return m
} }
@@ -290,19 +303,36 @@ func (m *Manager) onServerConnected() {
go m.onReconnectedListenerFn() go m.onReconnectedListenerFn()
} }
// onServerDisconnected start to reconnection for home server only // onServerDisconnected handles relay disconnect events. For the home server it
// starts the reconnect guard. For foreign servers it evicts the now-dead client
// from the cache so the next OpenConn builds a fresh one instead of reusing a
// closed client.
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock() m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL { isHome := m.relayClient != nil && serverAddress == m.relayClient.connectionURL
if isHome {
go func(client *Client) { go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client) m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient) }(m.relayClient)
} }
m.relayClientMu.Unlock() m.relayClientMu.Unlock()
if !isHome {
m.evictForeignRelay(serverAddress)
}
m.notifyOnDisconnectListeners(serverAddress) m.notifyOnDisconnectListeners(serverAddress)
} }
func (m *Manager) evictForeignRelay(serverAddress string) {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
if _, ok := m.relayClients[serverAddress]; ok {
delete(m.relayClients, serverAddress)
log.Debugf("evicted disconnected foreign relay client: %s", serverAddress)
}
}
func (m *Manager) listenGuardEvent(ctx context.Context) { func (m *Manager) listenGuardEvent(ctx context.Context) {
for { for {
select { select {

View File

@@ -2,6 +2,7 @@ package client
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"time" "time"
@@ -360,7 +361,8 @@ func TestAutoReconnect(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU,
WithMaxBackoffInterval(2*time.Second))
err = clientAlice.Serve() err = clientAlice.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
@@ -384,7 +386,9 @@ func TestAutoReconnect(t *testing.T) {
} }
log.Infof("waiting for reconnection") log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second) if err := waitForReady(ctx, clientAlice, 15*time.Second); err != nil {
t.Fatalf("manager did not reconnect: %s", err)
}
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ctx, ra, "bob") _, err = clientAlice.OpenConn(ctx, ra, "bob")
@@ -393,6 +397,21 @@ func TestAutoReconnect(t *testing.T) {
} }
} }
func waitForReady(ctx context.Context, m *Manager, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if m.Ready() {
return nil
}
select {
case <-time.After(100 * time.Millisecond):
case <-ctx.Done():
return ctx.Err()
}
}
return fmt.Errorf("manager not ready within %s", timeout)
}
func TestNotifierDoubleAdd(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background() ctx := context.Background()